use crate::reasoning::conversation::Conversation;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Describes a tool that can be offered to a model.
///
/// Serializes with the field names `name`, `description`, and `parameters`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Identifier of the tool.
    pub name: String,
    /// Human-readable explanation of what the tool does.
    pub description: String,
    /// Arbitrary JSON describing the tool's parameters.
    ///
    /// NOTE(review): presumably a JSON Schema object; nothing in this type
    /// validates that — confirm against the providers that consume it.
    pub parameters: serde_json::Value,
}
/// A single tool invocation requested in an [`InferenceResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
    /// Identifier of this particular call.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the call, carried as a raw string.
    ///
    /// NOTE(review): presumably JSON-encoded text; it is neither parsed nor
    /// validated by this type — confirm with the callers that decode it.
    pub arguments: String,
}
/// Why a completion ended.
///
/// Variants are (de)serialized in snake_case: `"stop"`, `"tool_calls"`,
/// `"max_tokens"`, and `"content_filter"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Presumably: the model finished on its own — confirm with providers.
    Stop,
    /// Presumably: the model stopped in order to request tool calls.
    ToolCalls,
    /// Presumably: the output token limit was reached.
    MaxTokens,
    /// Presumably: the provider's content filter cut the output short.
    ContentFilter,
}
/// Requested shape of the model output.
///
/// Serialized as an internally tagged object: the variant is stored under the
/// `"type"` key as `"text"`, `"json_object"`, or `"json_schema"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
// `rename_all` on an enum renames the variants only, producing exactly the
// snake_case tags listed above; the fields of `JsonSchema` keep their names.
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Tagged `"text"`.
    Text,
    /// Tagged `"json_object"`.
    JsonObject,
    /// Tagged `"json_schema"`; carries the schema and an optional name.
    JsonSchema {
        /// Arbitrary JSON value holding the schema.
        schema: serde_json::Value,
        /// Optional schema name; defaults to `None` when absent on input and
        /// is left out of the serialized form when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
/// Token counters attached to an [`InferenceResponse`].
///
/// The derived `Default` sets every counter to zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens counted for the prompt.
    pub prompt_tokens: u32,
    /// Tokens counted for the completion.
    pub completion_tokens: u32,
    /// Total token count.
    ///
    /// NOTE(review): presumably `prompt_tokens + completion_tokens`; the type
    /// does not enforce that relationship.
    pub total_tokens: u32,
}
/// Per-request settings passed to [`InferenceProvider::complete`].
///
/// Every field has a serde default, so missing keys fall back to the values
/// described on each field when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceOptions {
    /// Token limit for the request; falls back to `default_max_tokens()`.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    /// Sampling temperature; falls back to `default_temperature()`.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Tools made available for the request; empty by default and omitted
    /// from the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_definitions: Vec<ToolDefinition>,
    /// Requested output format; falls back to `default_response_format()`.
    #[serde(default = "default_response_format")]
    pub response_format: ResponseFormat,
    /// Optional model name; `None` by default and omitted from the
    /// serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Additional free-form JSON values keyed by name; empty by default and
    /// omitted from the serialized form when empty.
    ///
    /// NOTE(review): presumably provider-specific passthrough settings —
    /// confirm with the provider implementations.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra: HashMap<String, serde_json::Value>,
}
/// Fallback value for `InferenceOptions::max_tokens`, used both by serde when
/// the field is missing and by `InferenceOptions::default()`.
fn default_max_tokens() -> u32 {
    const DEFAULT_MAX_TOKENS: u32 = 4096;
    DEFAULT_MAX_TOKENS
}
/// Fallback value for `InferenceOptions::temperature`, used both by serde when
/// the field is missing and by `InferenceOptions::default()`.
fn default_temperature() -> f32 {
    const DEFAULT_TEMPERATURE: f32 = 0.3;
    DEFAULT_TEMPERATURE
}
/// Fallback value for `InferenceOptions::response_format` when the field is
/// missing during deserialization: plain text.
fn default_response_format() -> ResponseFormat {
    ResponseFormat::Text
}
impl Default for InferenceOptions {
fn default() -> Self {
Self {
max_tokens: default_max_tokens(),
temperature: default_temperature(),
tool_definitions: Vec::new(),
response_format: ResponseFormat::Text,
model: None,
extra: HashMap::new(),
}
}
}
/// Result of a successful [`InferenceProvider::complete`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// Text content of the response; may be empty (for example when the
    /// response only carries tool calls).
    pub content: String,
    /// Tool invocations requested in this response; empty when there are none.
    pub tool_calls: Vec<ToolCallRequest>,
    /// Why the completion ended.
    pub finish_reason: FinishReason,
    /// Token counters reported for this response.
    pub usage: Usage,
    /// Name of the model reported for this response.
    pub model: String,
}
impl InferenceResponse {
    /// Returns `true` when the response carries at least one entry in
    /// `tool_calls`, and `false` when that list is empty.
    pub fn has_tool_calls(&self) -> bool {
        !self.tool_calls.is_empty()
    }
}
/// Errors returned by [`InferenceProvider::complete`].
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum InferenceError {
    /// Failure reported with a free-form provider message.
    #[error("Provider error: {0}")]
    Provider(String),
    /// The request was rate limited; `retry_after_ms` is the suggested wait in
    /// milliseconds before retrying.
    #[error("Rate limited, retry after {retry_after_ms}ms")]
    RateLimited { retry_after_ms: u64 },
    /// The request did not fit the context window. The first value is the
    /// number of tokens requested, the second the number available.
    #[error("Context window exceeded: {0} tokens requested, {1} available")]
    ContextOverflow(usize, usize),
    /// The named model is not available.
    #[error("Model not available: {0}")]
    ModelUnavailable(String),
    /// The request was rejected as invalid; the payload explains why.
    #[error("Invalid request: {0}")]
    InvalidRequest(String),
    /// The call timed out after the given duration.
    #[error("Timeout after {0:?}")]
    Timeout(std::time::Duration),
    /// The response could not be parsed; the payload explains why.
    #[error("Response parse error: {0}")]
    ParseError(String),
}
/// Abstraction over a backend that can complete a [`Conversation`].
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait InferenceProvider: Send + Sync {
    /// Produces an [`InferenceResponse`] for `conversation` under `options`,
    /// or an [`InferenceError`] on failure.
    async fn complete(
        &self,
        conversation: &Conversation,
        options: &InferenceOptions,
    ) -> Result<InferenceResponse, InferenceError>;

    /// Name identifying this provider.
    fn provider_name(&self) -> &str;

    /// Name of this provider's default model.
    ///
    /// NOTE(review): presumably used when `InferenceOptions::model` is `None`
    /// — confirm with the implementations.
    fn default_model(&self) -> &str;

    /// Whether this provider reports native tool support.
    ///
    /// NOTE(review): presumably governs whether `tool_definitions` are sent
    /// through a provider tool API — confirm with the callers.
    fn supports_native_tools(&self) -> bool;

    /// Whether this provider reports structured-output support.
    ///
    /// NOTE(review): presumably relates to the non-text [`ResponseFormat`]
    /// variants — confirm with the callers.
    fn supports_structured_output(&self) -> bool;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `InferenceOptions::default()` exposes the expected fallback values.
    #[test]
    fn test_inference_options_default() {
        let defaults = InferenceOptions::default();

        assert_eq!(defaults.max_tokens, 4096);
        // Compare floats by distance rather than with `==`.
        let temperature_delta = (defaults.temperature - 0.3).abs();
        assert!(temperature_delta < f32::EPSILON);
        assert!(defaults.tool_definitions.is_empty());
        assert!(matches!(defaults.response_format, ResponseFormat::Text));
    }

    /// A `ToolDefinition` survives a JSON round trip.
    #[test]
    fn test_tool_definition_serde() {
        let parameters = serde_json::json!({
            "type": "object",
            "properties": {
                "query": { "type": "string" }
            },
            "required": ["query"]
        });
        let definition = ToolDefinition {
            name: String::from("web_search"),
            description: String::from("Search the web"),
            parameters,
        };

        let encoded = serde_json::to_string(&definition).unwrap();
        let decoded: ToolDefinition = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.name, "web_search");
    }

    /// Serialized `ResponseFormat` values mention their tag and payload.
    #[test]
    fn test_response_format_serde() {
        let text_json = serde_json::to_string(&ResponseFormat::Text).unwrap();
        assert!(text_json.contains("text"));

        let with_schema = ResponseFormat::JsonSchema {
            schema: serde_json::json!({"type": "object"}),
            name: Some(String::from("MySchema")),
        };
        let schema_json = serde_json::to_string(&with_schema).unwrap();
        assert!(schema_json.contains("json_schema"));
        assert!(schema_json.contains("MySchema"));
    }

    /// `has_tool_calls` reflects whether `tool_calls` is non-empty.
    #[test]
    fn test_inference_response_has_tool_calls() {
        let single_call = ToolCallRequest {
            id: String::from("tc_1"),
            name: String::from("search"),
            arguments: String::from("{}"),
        };
        let with_tools = InferenceResponse {
            content: String::new(),
            tool_calls: vec![single_call],
            finish_reason: FinishReason::ToolCalls,
            usage: Usage::default(),
            model: String::from("test"),
        };
        assert!(with_tools.has_tool_calls());

        let without_tools = InferenceResponse {
            content: String::from("Hello"),
            tool_calls: Vec::new(),
            finish_reason: FinishReason::Stop,
            usage: Usage::default(),
            model: String::from("test"),
        };
        assert!(!without_tools.has_tool_calls());
    }

    /// `FinishReason` uses snake_case on the wire and round-trips.
    #[test]
    fn test_finish_reason_serde() {
        let encoded = serde_json::to_string(&FinishReason::ToolCalls).unwrap();
        assert_eq!(encoded, "\"tool_calls\"");

        let decoded: FinishReason = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, FinishReason::ToolCalls);
    }
}