use super::TeacherProvider;
use super::security::EndpointSecurity;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Configuration describing how to reach and query a "teacher" model.
///
/// Instances are normally obtained from one of the presets
/// (`claude_sonnet`, `claude_haiku`, `gpt4`, `gemini_pro`, `local`) or through
/// `TeacherConfig::builder()`. Call `validate` to check the values.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TeacherConfig {
/// Backend this configuration targets. When `endpoint` is `None`,
/// `get_endpoint` picks a default URL based on this value.
pub provider: TeacherProvider,
/// Model identifier. `validate` rejects an empty string.
pub model_id: String,
/// Optional endpoint override. When `None`, `get_endpoint` falls back to a
/// per-provider default.
pub endpoint: Option<String>,
/// Name of an environment variable, presumably the one holding the API key
/// (the hosted presets use e.g. `ANTHROPIC_API_KEY`; the `local` preset
/// leaves it empty). NOTE(review): confirm where this is read.
pub api_key_env: String,
/// Sampling temperature. `validate` requires it to lie in `0.0..=2.0`.
pub temperature: f32,
/// Token limit (presumably for the generated response — confirm against the
/// client code). `validate` rejects `0`.
pub max_tokens: usize,
/// Timeout in milliseconds. `validate` rejects `0`.
pub timeout_ms: u64,
/// Retry count. Not checked by `validate`.
pub max_retries: usize,
/// Delay between retries in milliseconds. Not checked by `validate`.
pub retry_delay_ms: u64,
/// Optional system prompt. Every preset fills it with
/// `TeacherConfig::default_system_prompt()`.
pub system_prompt: Option<String>,
/// Endpoint security policy consulted by `validate` and `verify_endpoint`.
pub security: EndpointSecurity,
}
impl Default for TeacherConfig {
fn default() -> Self {
Self::claude_sonnet()
}
}
impl TeacherConfig {
pub fn claude_sonnet() -> Self {
Self {
provider: TeacherProvider::Claude,
model_id: "claude-sonnet-4-20250514".to_string(),
endpoint: None,
api_key_env: "ANTHROPIC_API_KEY".to_string(),
temperature: 0.3,
max_tokens: 1024,
timeout_ms: 30000,
max_retries: 3,
retry_delay_ms: 1000,
system_prompt: Some(Self::default_system_prompt()),
security: EndpointSecurity::default_secure(),
}
}
pub fn claude_haiku() -> Self {
Self {
provider: TeacherProvider::Claude,
model_id: "claude-3-5-haiku-20241022".to_string(),
endpoint: None,
api_key_env: "ANTHROPIC_API_KEY".to_string(),
temperature: 0.3,
max_tokens: 1024,
timeout_ms: 15000,
max_retries: 3,
retry_delay_ms: 500,
system_prompt: Some(Self::default_system_prompt()),
security: EndpointSecurity::default_secure(),
}
}
pub fn gpt4() -> Self {
Self {
provider: TeacherProvider::OpenAI,
model_id: "gpt-4-turbo-preview".to_string(),
endpoint: None,
api_key_env: "OPENAI_API_KEY".to_string(),
temperature: 0.3,
max_tokens: 1024,
timeout_ms: 30000,
max_retries: 3,
retry_delay_ms: 1000,
system_prompt: Some(Self::default_system_prompt()),
security: EndpointSecurity::default_secure(),
}
}
pub fn gemini_pro() -> Self {
Self {
provider: TeacherProvider::Gemini,
model_id: "gemini-pro".to_string(),
endpoint: None,
api_key_env: "GOOGLE_API_KEY".to_string(),
temperature: 0.3,
max_tokens: 1024,
timeout_ms: 30000,
max_retries: 3,
retry_delay_ms: 1000,
system_prompt: Some(Self::default_system_prompt()),
security: EndpointSecurity::default_secure(),
}
}
pub fn local(model_id: impl Into<String>, endpoint: impl Into<String>) -> Self {
Self {
provider: TeacherProvider::Local,
model_id: model_id.into(),
endpoint: Some(endpoint.into()),
api_key_env: String::new(),
temperature: 0.3,
max_tokens: 1024,
timeout_ms: 60000,
max_retries: 2,
retry_delay_ms: 500,
system_prompt: Some(Self::default_system_prompt()),
security: EndpointSecurity::for_local(),
}
}
pub fn builder() -> TeacherConfigBuilder {
TeacherConfigBuilder::new()
}
pub fn default_system_prompt() -> String {
r#"You are an intent classification system. Given a conversation context and a new message, classify the intent of the new message.
Output a JSON object with probability scores for each intent class. Scores should sum to approximately 1.0.
Intent classes:
- continuation: Natural conversation continuation
- topic_shift: User is changing the topic
- explicit_query: Direct question or request for information
- person_lookup: Looking up information about a person/contact
- health_check: Health or wellness related inquiry
- task_status: Checking on task or todo status
Example output:
{"continuation": 0.7, "topic_shift": 0.1, "explicit_query": 0.15, "person_lookup": 0.02, "health_check": 0.02, "task_status": 0.01}
Respond ONLY with the JSON object, no other text."#.to_string()
}
pub fn validate(&self) -> Result<(), String> {
if self.model_id.is_empty() {
return Err("model_id cannot be empty".to_string());
}
if !(0.0..=2.0).contains(&self.temperature) {
return Err(format!(
"temperature must be between 0.0 and 2.0, got {}",
self.temperature
));
}
if self.max_tokens == 0 {
return Err("max_tokens must be > 0".to_string());
}
if self.timeout_ms == 0 {
return Err("timeout_ms must be > 0".to_string());
}
if let Some(ref endpoint) = self.endpoint {
self.security.verify_endpoint(endpoint)?;
}
Ok(())
}
pub fn verify_endpoint(&self) -> Result<(), String> {
let endpoint = self.get_endpoint();
self.security.verify_endpoint(&endpoint)?;
self.security.validate_cert_fingerprint("")?;
Ok(())
}
pub fn get_endpoint(&self) -> String {
if let Some(ref ep) = self.endpoint {
ep.clone()
} else {
match self.provider {
TeacherProvider::Claude => "https://api.anthropic.com/v1".to_string(),
TeacherProvider::OpenAI => "https://api.openai.com/v1".to_string(),
TeacherProvider::Gemini => {
"https://generativelanguage.googleapis.com/v1".to_string()
}
TeacherProvider::Local => "http://localhost:11434".to_string(),
TeacherProvider::Custom(_) => "".to_string(),
}
}
}
pub fn with_security(mut self, security: EndpointSecurity) -> Self {
self.security = security;
self
}
pub fn display_name(&self) -> String {
format!("{}:{}", self.provider, self.model_id)
}
}
/// Fluent builder for [`TeacherConfig`].
///
/// Wraps a complete `TeacherConfig` value that the setter methods overwrite
/// field by field; `build` returns it as is.
#[derive(Debug, Clone)]
pub struct TeacherConfigBuilder {
// The configuration being assembled.
config: TeacherConfig,
}
impl TeacherConfigBuilder {
    /// Creates a builder seeded with `TeacherConfig::claude_sonnet()`.
    ///
    /// Each setter overwrites exactly one field. In particular, calling
    /// [`provider`](Self::provider) alone leaves the Claude-specific `model_id`
    /// and `api_key_env` from the seed in place; set those explicitly when
    /// switching providers.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: TeacherConfig::claude_sonnet(),
        }
    }

    /// Sets the provider. No other field is adjusted.
    #[must_use]
    pub fn provider(mut self, provider: TeacherProvider) -> Self {
        self.config.provider = provider;
        self
    }

    /// Sets the model identifier.
    #[must_use]
    pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
        self.config.model_id = model_id.into();
        self
    }

    /// Sets an explicit endpoint override.
    #[must_use]
    pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.config.endpoint = Some(endpoint.into());
        self
    }

    /// Sets the value stored in `api_key_env`.
    #[must_use]
    pub fn api_key_env(mut self, env_var: impl Into<String>) -> Self {
        self.config.api_key_env = env_var.into();
        self
    }

    /// Sets the sampling temperature. The value is not range-checked here.
    #[must_use]
    pub fn temperature(mut self, temp: f32) -> Self {
        self.config.temperature = temp;
        self
    }

    /// Sets `max_tokens`.
    #[must_use]
    pub fn max_tokens(mut self, tokens: usize) -> Self {
        self.config.max_tokens = tokens;
        self
    }

    /// Sets the timeout in milliseconds.
    #[must_use]
    pub fn timeout_ms(mut self, ms: u64) -> Self {
        self.config.timeout_ms = ms;
        self
    }

    /// Sets `max_retries`.
    #[must_use]
    pub fn max_retries(mut self, retries: usize) -> Self {
        self.config.max_retries = retries;
        self
    }

    /// Sets the retry delay in milliseconds.
    #[must_use]
    pub fn retry_delay_ms(mut self, ms: u64) -> Self {
        self.config.retry_delay_ms = ms;
        self
    }

    /// Sets a custom system prompt, replacing the seeded default.
    #[must_use]
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.config.system_prompt = Some(prompt.into());
        self
    }

    /// Clears the system prompt.
    #[must_use]
    pub fn no_system_prompt(mut self) -> Self {
        self.config.system_prompt = None;
        self
    }

    /// Sets the endpoint security policy.
    #[must_use]
    pub fn security(mut self, security: EndpointSecurity) -> Self {
        self.config.security = security;
        self
    }

    /// Finishes the builder and returns the assembled configuration.
    ///
    /// No validation is performed; call `TeacherConfig::validate` on the
    /// result to check the values.
    #[must_use]
    pub fn build(self) -> TeacherConfig {
        self.config
    }
}
impl Default for TeacherConfigBuilder {
fn default() -> Self {
Self::new()
}
}