use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use super::error::ProviderResult;
use super::types::{CompletionResponse, Message, StreamEvent};
use crate::tools::tool::ToolDefinition;
/// Everything a [`Provider`] needs in order to produce one completion.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionRequest {
    /// Model identifier to send the request to.
    pub model: String,
    /// System prompt accompanying the conversation.
    pub system_prompt: String,
    /// Conversation history for this request.
    pub messages: Vec<Message>,
    /// Tools offered to the model; may be empty.
    pub tools: Vec<ToolDefinition>,
    /// Optional token limit; `None` leaves it unset.
    /// NOTE(review): whether this bounds output or total tokens is up to each
    /// provider implementation — confirm.
    pub max_request_tokens: Option<u32>,
    /// Optional tool-selection policy; `None` leaves it unset.
    pub tool_choice: Option<ToolChoice>,
}
/// Tool-selection policy attached to a [`CompletionRequest`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum ToolChoice {
    /// Leave the choice to the model (presumably including calling no tool —
    /// confirm per provider).
    Auto,
    /// Ask for one particular tool.
    Specific {
        /// Tool name; presumably one of the request's `tools` — verify at call sites.
        name: String,
    },
}
/// A backend that can answer a [`CompletionRequest`].
///
/// The `Send + Sync` supertraits let one provider instance be shared across
/// threads and tasks. Each method returns a boxed `Send` future that borrows
/// `self`, which keeps the trait usable as `dyn Provider`.
pub trait Provider: Send + Sync {
    /// Runs one completion and resolves to the full response.
    fn complete(
        &self,
        request: CompletionRequest,
    ) -> Pin<Box<dyn Future<Output = ProviderResult<CompletionResponse>> + Send + '_>>;

    /// Runs one completion, reporting progress through `on_event`.
    ///
    /// The default does not stream incrementally.
    /// - It forwards `request` to [`Provider::complete`].
    /// - On success, it calls `on_event` exactly once with
    ///   [`StreamEvent::MessageDone`], then returns the response.
    /// - On failure, the error is returned unchanged and `on_event` is never called.
    ///
    /// Providers with native streaming are expected to override this.
    fn complete_streaming(
        &self,
        request: CompletionRequest,
        on_event: Arc<dyn Fn(StreamEvent) + Send + Sync>,
    ) -> Pin<Box<dyn Future<Output = ProviderResult<CompletionResponse>> + Send + '_>> {
        Box::pin(async move {
            // `complete` is invoked only once this future is first polled.
            self.complete(request).await.map(|response| {
                on_event(StreamEvent::MessageDone);
                response
            })
        })
    }

    /// Hook for warming up the backend before first use.
    ///
    /// The default is a no-op future that is immediately ready.
    fn prewarm(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        Box::pin(std::future::ready(()))
    }
}
/// Builds the HTTP client used by providers.
///
/// - The connect timeout (10 s) bounds establishing the connection.
/// - The read timeout is a per-read inactivity timeout: it resets after every
///   successful read and also applies while waiting for the response.
///
/// A completion request can legitimately go a long time without receiving
/// bytes. A non-streaming completion typically sends nothing until generation
/// finishes, and streams may pause before the first token. The read timeout
/// must therefore be generous; a seconds-scale value would abort ordinary
/// requests.
///
/// # Panics
///
/// Panics if the client cannot be built (e.g. the TLS backend fails to
/// initialise), which is treated as an unrecoverable setup error.
pub(crate) fn default_client() -> reqwest::Client {
    const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
    // Long enough for slow, non-streamed completions; still detects dead connections.
    const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);

    reqwest::Client::builder()
        .connect_timeout(CONNECT_TIMEOUT)
        .read_timeout(READ_TIMEOUT)
        .build()
        .expect("reqwest::Client with timeouts should build")
}
/// Best-effort warm-up: sends a `HEAD` request to `base_url` through `client`.
///
/// It waits at most 10 seconds. The outcome is discarded, whether that is a
/// response, an HTTP/transport error, or the timeout elapsing. This function
/// therefore never fails. Presumably the goal is to get a pooled connection
/// ready for later requests — confirm.
pub(crate) async fn prewarm_with(client: &reqwest::Client, base_url: &str) {
    let deadline = std::time::Duration::from_secs(10);
    let probe = client.head(base_url).send();
    // Intentionally ignored: warm-up failures must not affect callers.
    let _ = tokio::time::timeout(deadline, probe).await;
}