use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::error::ProviderResult;
use super::types::{Message, ModelResponse, StreamEvent};
/// A tool as it is described to a model provider.
///
/// Field order is significant for serialization output, so keep it stable.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ProviderToolDefinition {
    /// Name the model uses to refer to (and call) this tool.
    pub name: String,
    /// Human-readable explanation of what the tool does.
    pub description: String,
    /// Schema of the tool's input, as an arbitrary JSON value
    /// (presumably a JSON Schema object — confirm against providers).
    pub input_schema: Value,
}
/// All inputs for a single [`Provider::respond`] call.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ModelRequest {
    /// Identifier of the model to query.
    pub model: String,
    /// System prompt sent alongside the conversation.
    pub system_prompt: String,
    /// Conversation history to send to the model.
    pub messages: Vec<Message>,
    /// Tools the model may be offered for this request.
    pub tools: Vec<ProviderToolDefinition>,
    /// Optional token limit for the request; `None` means no explicit limit.
    /// NOTE(review): whether this caps output or total tokens is provider-specific — confirm.
    pub max_request_tokens: Option<u32>,
    /// Optional tool-selection constraint; `None` leaves it unspecified.
    pub tool_choice: Option<ToolChoice>,
}
/// How the model should pick among the offered tools.
///
/// Variant order is kept stable because some serde formats encode variants by index.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub enum ToolChoice {
    /// Leave tool selection to the model (presumably including "no tool").
    Auto,
    /// Request the tool whose name is `name`.
    Specific { name: String },
}
/// A model backend that can answer a [`ModelRequest`].
///
/// Implementors must be `Send + Sync` so a provider can be shared across tasks.
pub trait Provider: Send + Sync {
    /// Sends `request` to the model and resolves to its final response.
    ///
    /// `on_event` is a shared callback for [`StreamEvent`]s; when and how often
    /// it is invoked is up to the implementation.
    ///
    /// The returned future borrows `self` and is `Send`.
    fn respond(
        &self,
        request: ModelRequest,
        on_event: Arc<dyn Fn(StreamEvent) + Send + Sync>,
    ) -> Pin<Box<dyn Future<Output = ProviderResult<ModelResponse>> + Send + '_>>;

    /// Optional warm-up hook. The default implementation completes immediately
    /// without doing anything.
    fn prewarm(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        // An already-resolved future; equivalent to an empty `async` block.
        Box::pin(std::future::ready(()))
    }
}
/// Default per-client request timeout: ten minutes.
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// Builds a `reqwest::Client` whose requests are bounded by `timeout`
/// (see `reqwest::ClientBuilder::timeout` for exactly what the bound covers).
///
/// # Panics
///
/// Panics if the builder fails to produce a client. Only a timeout is
/// configured here, so that failure is treated as an invariant violation.
pub(crate) fn build_client(timeout: Duration) -> reqwest::Client {
    let configured = reqwest::Client::builder().timeout(timeout);
    configured
        .build()
        .expect("reqwest::Client with timeout should build")
}
/// Best-effort warm-up: sends a `HEAD` request to `base_url` and discards the outcome.
///
/// The attempt is abandoned after 10 seconds. Success, HTTP errors, transport
/// errors and the timeout are all ignored. Presumably this exists to open a pooled
/// connection ahead of the first real request (TODO confirm).
pub(crate) async fn prewarm_with(client: &reqwest::Client, base_url: &str) {
    const PREWARM_TIMEOUT: Duration = Duration::from_secs(10);

    let probe = client.head(base_url).send();
    // The result is intentionally ignored; warm-up failure must not surface to callers.
    let _ = tokio::time::timeout(PREWARM_TIMEOUT, probe).await;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the default timeout so an accidental change to the constant is caught.
    #[test]
    fn default_request_timeout_is_ten_minutes() {
        let ten_minutes = Duration::from_secs(10 * 60);
        assert_eq!(DEFAULT_REQUEST_TIMEOUT, ten_minutes);
    }
}