use crate::core::types::health::HealthStatus;
use crate::core::types::{chat::ChatRequest, context::RequestContext, responses::ChatResponse};
use crate::utils::error::gateway_error::GatewayError;
use super::llm_provider::trait_definition::LLMProvider;
/// A named, weighted, enable-able wrapper around a single provider.
///
/// The concrete provider value is kept type-erased behind
/// `Arc<dyn Any + Send + Sync>`. Nothing in this file downcasts or calls it;
/// it is only held so the provider stays alive as long as the handle does.
pub struct ProviderHandle {
    /// Provider name, copied out of the provider when the handle is built.
    name: String,
    /// Caller-supplied relative weight; stored as-is (no range or NaN check).
    weight: f64,
    /// Toggle flag read by `is_enabled` and written by `set_enabled`.
    enabled: bool,
    /// The wrapped provider, type-erased; retained but never invoked from this file.
    _provider: std::sync::Arc<dyn std::any::Any + Send + Sync>,
}
impl ProviderHandle {
    /// Builds an enabled handle around `provider` with the given `weight`.
    ///
    /// The provider's name is captured while the value is still concretely
    /// typed; the provider is then moved behind a type-erased
    /// `Arc<dyn Any + Send + Sync>`. `weight` is stored exactly as passed.
    pub fn new<P>(provider: P, weight: f64) -> Self
    where
        P: LLMProvider + Send + Sync + 'static,
    {
        // The name has to be read before `provider` is moved into the Arc.
        let name = provider.name().to_string();
        let erased: std::sync::Arc<dyn std::any::Any + Send + Sync> =
            std::sync::Arc::new(provider);

        Self {
            name,
            _provider: erased,
            weight,
            enabled: true,
        }
    }

    /// Name captured from the provider at construction time.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Weight passed to [`ProviderHandle::new`].
    pub fn weight(&self) -> f64 {
        self.weight
    }

    /// Current value of the enabled flag (`true` right after construction).
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Overwrites the enabled flag.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Stub: claims support for every model name; `_model` is not inspected.
    pub fn supports_model(&self, _model: &str) -> bool {
        true
    }

    /// Stub: always reports tool support.
    pub fn supports_tools(&self) -> bool {
        true
    }

    /// Placeholder for dispatching a chat request to the wrapped provider.
    ///
    /// # Errors
    ///
    /// Always returns `GatewayError::Internal`. The provider is held type-erased
    /// and is never called, so `_request` and `_context` are dropped unused.
    pub async fn chat_completion(
        &self,
        _request: ChatRequest,
        _context: RequestContext,
    ) -> Result<ChatResponse, GatewayError> {
        let reason = String::from("Provider chat_completion not implemented");
        Err(GatewayError::Internal(reason))
    }

    /// Stub: reports `HealthStatus::Healthy` unconditionally (the enabled flag
    /// is not consulted).
    pub async fn health_check(&self) -> HealthStatus {
        HealthStatus::Healthy
    }

    /// Stub: returns a cost of `0.0`; the model and token counts are ignored.
    pub async fn calculate_cost(
        &self,
        _model: &str,
        _input: u32,
        _output: u32,
    ) -> Result<f64, GatewayError> {
        let zero_cost: f64 = 0.0;
        Ok(zero_cost)
    }

    /// Stub: returns a fixed 100 ms latency rather than a measured value.
    pub async fn get_average_latency(&self) -> Result<std::time::Duration, GatewayError> {
        const PLACEHOLDER_LATENCY_MS: u64 = 100;
        Ok(std::time::Duration::from_millis(PLACEHOLDER_LATENCY_MS))
    }

    /// Stub: returns a fixed success rate of `1.0`.
    pub async fn get_success_rate(&self) -> Result<f32, GatewayError> {
        let always_successful: f32 = 1.0;
        Ok(always_successful)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handle straight from its fields so these tests do not depend on a
    /// concrete `LLMProvider` implementation; `()` stands in for the erased provider.
    fn sample_handle(weight: f64) -> ProviderHandle {
        ProviderHandle {
            name: "sample-provider".to_string(),
            _provider: std::sync::Arc::new(()),
            weight,
            enabled: true,
        }
    }

    #[test]
    fn test_provider_handle_accessors() {
        let mut handle = sample_handle(2.5);
        assert_eq!(handle.name(), "sample-provider");
        assert_eq!(handle.weight(), 2.5);
        assert!(handle.is_enabled());

        handle.set_enabled(false);
        assert!(!handle.is_enabled());

        handle.set_enabled(true);
        assert!(handle.is_enabled());
    }

    #[test]
    fn test_provider_handle_capability_stubs() {
        let handle = sample_handle(1.0);
        assert!(handle.supports_model("any-model"));
        assert!(handle.supports_tools());
    }

    // Pins the current placeholder behaviour so a real implementation has to
    // update these expectations deliberately.
    #[tokio::test]
    async fn test_provider_handle_async_stubs() {
        let handle = sample_handle(1.0);

        assert!(matches!(handle.health_check().await, HealthStatus::Healthy));
        assert!(matches!(
            handle.calculate_cost("test-model", 10, 20).await,
            Ok(c) if c == 0.0
        ));
        assert!(matches!(
            handle.get_average_latency().await,
            Ok(d) if d == std::time::Duration::from_millis(100)
        ));
        assert!(matches!(handle.get_success_rate().await, Ok(r) if r == 1.0));

        let result = handle
            .chat_completion(ChatRequest::default(), RequestContext::default())
            .await;
        assert!(matches!(
            result,
            Err(GatewayError::Internal(ref msg))
                if msg.as_str() == "Provider chat_completion not implemented"
        ));
    }

    #[tokio::test]
    async fn test_health_status_healthy() {
        let status = HealthStatus::Healthy;
        assert!(matches!(status, HealthStatus::Healthy));
    }

    #[tokio::test]
    async fn test_health_status_unhealthy() {
        let status = HealthStatus::Unhealthy;
        assert!(matches!(status, HealthStatus::Unhealthy));
    }

    #[tokio::test]
    async fn test_health_status_degraded() {
        let status = HealthStatus::Degraded;
        assert!(matches!(status, HealthStatus::Degraded));
    }

    #[test]
    fn test_request_context_default() {
        let context = RequestContext::default();
        assert!(!context.request_id.is_empty());
    }

    #[test]
    fn test_chat_request_default() {
        let request = ChatRequest {
            model: "test-model".to_string(),
            messages: vec![],
            ..Default::default()
        };
        assert_eq!(request.model, "test-model");
        assert!(request.messages.is_empty());
    }
}