use crate::config::ModelEndpoint;
use crate::error::{AppError, AppResult};
use open_agent::{AgentOptions, Client};
use std::sync::Arc;
use tokio::sync::Mutex;
/// Pairs a configured [`ModelEndpoint`] with an `open_agent` [`Client`] built from it.
///
/// Instances are created by [`ModelClient::new`], which derives the client's
/// options (model name, base URL, token limit, temperature) from the endpoint.
// NOTE(review): `dead_code` is allowed presumably because nothing uses this type
// yet — confirm, and drop the attribute once it is wired into callers.
#[allow(dead_code)]
pub struct ModelClient {
/// The endpoint configuration this client was constructed from.
endpoint: ModelEndpoint,
/// Shared handle to the underlying client; guarded by a tokio (async) `Mutex`.
client: Arc<Mutex<Client>>,
}
impl ModelClient {
    /// Creates a client for `endpoint`.
    ///
    /// Copies the endpoint's model name, base URL, token limit and temperature
    /// into an [`AgentOptions`] and builds an `open_agent` [`Client`] from it.
    /// The client is stored behind `Arc<Mutex<_>>` (tokio's async mutex) so it
    /// can be shared through [`ModelClient::client`].
    ///
    /// # Errors
    ///
    /// Returns [`AppError::Internal`] when:
    /// - `endpoint.max_tokens()` does not fit in a `u32`;
    /// - building the [`AgentOptions`] fails;
    /// - [`Client::new`] rejects the options.
    // NOTE(review): `dead_code` is allowed presumably because no caller uses this
    // constructor yet — confirm, and remove the attribute once it is wired up.
    #[allow(dead_code)]
    pub fn new(endpoint: ModelEndpoint) -> AppResult<Self> {
        // The builder takes a `u32`. A plain `as` cast would silently wrap or
        // truncate an oversized limit, so reject it with a descriptive error.
        let raw_max_tokens = endpoint.max_tokens();
        let max_tokens = u32::try_from(raw_max_tokens).map_err(|e| {
            AppError::Internal(format!(
                "max_tokens {} for {} does not fit in u32: {}",
                raw_max_tokens,
                endpoint.name(),
                e
            ))
        })?;

        let options = AgentOptions::builder()
            .model(endpoint.name())
            .base_url(endpoint.base_url())
            .max_tokens(max_tokens)
            // Float-to-float `as` only rounds (it never wraps); this assumes
            // `temperature()` returns a float, presumably `f64` — confirm.
            .temperature(endpoint.temperature() as f32)
            .build()
            .map_err(|e| {
                AppError::Internal(format!(
                    "Failed to build AgentOptions for {}: {}",
                    endpoint.name(),
                    e
                ))
            })?;

        let client = Client::new(options).map_err(|e| {
            AppError::Internal(format!(
                "Failed to create client for {}: {}",
                endpoint.name(),
                e
            ))
        })?;

        Ok(Self {
            endpoint,
            client: Arc::new(Mutex::new(client)),
        })
    }

    /// Returns the endpoint configuration this client was built from.
    #[allow(dead_code)]
    pub fn endpoint(&self) -> &ModelEndpoint {
        &self.endpoint
    }

    /// Returns the shared handle to the underlying `open_agent` client.
    ///
    /// The client sits behind a tokio `Mutex`, so callers acquire it with
    /// `.lock().await`; clone the `Arc` to keep a handle beyond the borrow of
    /// `self`.
    #[allow(dead_code)]
    pub fn client(&self) -> &Arc<Mutex<Client>> {
        &self.client
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds a representative endpoint definition through serde and checks
    /// that each accessor reports the value written in the JSON document.
    #[test]
    fn test_model_endpoint_deserialization() {
        let raw = r#"{
"name": "test-model",
"base_url": "http://localhost:1234/v1",
"max_tokens": 2048,
"temperature": 0.7,
"weight": 1.0,
"priority": 1
}"#;

        let parsed = serde_json::from_str::<ModelEndpoint>(raw)
            .expect("should deserialize ModelEndpoint");

        // Identity and transport settings.
        assert_eq!(parsed.name(), "test-model");
        assert_eq!(parsed.base_url(), "http://localhost:1234/v1");
        // Generation parameters.
        assert_eq!(parsed.max_tokens(), 2048);
        assert_eq!(parsed.temperature(), 0.7);
        // Routing parameters.
        assert_eq!(parsed.weight(), 1.0);
        assert_eq!(parsed.priority(), 1);
    }
}