use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
/// Optional generation parameters handed to [`ModelProvider::generate`] and
/// [`ModelProvider::stream`].
///
/// Every field is an `Option`. The [`Default`] implementation sets
/// `temperature`, `max_tokens` and `top_p` and leaves the other fields `None`.
///
/// NOTE(review): the field names follow common LLM sampling parameters; how a
/// provider interprets each value (or a `None`) is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
/// Optional `temperature` value (`Default` uses `Some(0.7)`).
pub temperature: Option<f32>,
/// Optional token limit (`Default` uses `Some(2048)`).
pub max_tokens: Option<u32>,
/// Optional list of stop strings (`Default` uses `None`).
pub stop_sequences: Option<Vec<String>>,
/// Optional `top_p` value (`Default` uses `Some(1.0)`).
pub top_p: Option<f32>,
/// Optional frequency penalty (`Default` uses `None`).
pub frequency_penalty: Option<f32>,
/// Optional presence penalty (`Default` uses `None`).
pub presence_penalty: Option<f32>,
}
impl Default for GenerationConfig {
fn default() -> Self {
Self {
temperature: Some(0.7),
max_tokens: Some(2048),
stop_sequences: None,
top_p: Some(1.0),
frequency_penalty: None,
presence_penalty: None,
}
}
}
/// A single tool/function invocation, carried in [`ModelResponse::tool_calls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this call (NOTE(review): presumably assigned by the provider — confirm).
pub id: String,
/// Name of the function to invoke.
pub function_name: String,
/// Arguments for the call as an arbitrary JSON value.
pub arguments: serde_json::Value,
}
/// Splits a model response into its `<think>…</think>` reasoning and the
/// remaining content.
///
/// # Returns
///
/// A `(reasoning, content)` pair:
///
/// * `reasoning` is the trimmed text between the first `<think>` and the first
///   `</think>` that follows it, or `None` when no such pair exists (for
///   example when the opening tag is never closed). An empty block yields
///   `Some("")`.
/// * `content` is the trimmed text after the first `</think>` in the response.
///   When the response contains no `</think>` at all, the response is returned
///   unchanged (not trimmed).
///
/// Note that any text located before the first `</think>` (including text in
/// front of `<think>`) is not part of `content`.
pub fn parse_thinking_tokens(response: &str) -> (Option<String>, String) {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    // Plain substring splitting matches the same span as the lazy pattern
    // `<think>([\s\S]*?)</think>` without compiling a regex on every call.
    let reasoning = response
        .split_once(OPEN_TAG)
        .and_then(|(_, after_open)| after_open.split_once(CLOSE_TAG))
        .map(|(inner, _)| inner.trim().to_string());

    // The content boundary is the first closing tag anywhere in the response,
    // independent of whether an opening tag precedes it.
    let content = match response.split_once(CLOSE_TAG) {
        Some((_, after_close)) => after_close.trim().to_string(),
        None => response.to_string(),
    };

    (reasoning, content)
}
/// The result returned by [`ModelProvider::generate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
/// Response text.
pub content: String,
/// Model name (NOTE(review): presumably the model that produced the response — confirm).
pub model: String,
/// Token accounting, when available.
pub usage: Option<TokenUsage>,
/// Reason generation finished, when available (value set is provider-defined).
pub finish_reason: Option<String>,
/// Tool calls attached to the response; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// Reasoning text, if any; omitted from serialized output when `None`.
/// NOTE(review): likely filled from [`parse_thinking_tokens`] — confirm at call sites.
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<String>,
}
/// Token counts attached to a [`ModelResponse`].
///
/// NOTE(review): `total_tokens` is presumably `prompt_tokens + completion_tokens`,
/// but nothing in this file enforces that relationship.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
/// Number of tokens in the prompt.
pub prompt_tokens: u32,
/// Number of tokens in the completion.
pub completion_tokens: u32,
/// Total number of tokens.
pub total_tokens: u32,
}
/// Binary attachment accepted by [`ModelProvider::generate_with_attachments`]
/// and [`ModelProvider::stream_with_attachments`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageAttachment {
/// MIME type string (NOTE(review): presumably e.g. `image/png`; not validated here).
pub mime: String,
/// Raw attachment bytes.
pub data: Vec<u8>,
}
/// Descriptive information returned by [`ModelProvider::metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderMetadata {
/// Provider name.
pub name: String,
/// Names of the models the provider reports as supported.
pub supported_models: Vec<String>,
/// Whether the provider reports streaming support.
pub supports_streaming: bool,
}
/// Identifies a provider backend.
///
/// `Mock` is always compiled in; every other variant exists only when the
/// cargo feature of the same (lowercase) name is enabled. With serde the
/// variants are (de)serialized using their lowercase names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderKind {
/// Always-available variant (string form: `"mock"`).
Mock,
/// Available with the `openai` feature.
#[cfg(feature = "openai")]
OpenAI,
/// Available with the `anthropic` feature.
#[cfg(feature = "anthropic")]
Anthropic,
/// Available with the `ollama` feature.
#[cfg(feature = "ollama")]
Ollama,
/// Available with the `mlx` feature.
#[cfg(feature = "mlx")]
MLX,
/// Available with the `lmstudio` feature.
#[cfg(feature = "lmstudio")]
LMStudio,
}
impl ProviderKind {
#[allow(clippy::should_implement_trait)]
pub fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"mock" => Some(ProviderKind::Mock),
#[cfg(feature = "openai")]
"openai" => Some(ProviderKind::OpenAI),
#[cfg(feature = "anthropic")]
"anthropic" => Some(ProviderKind::Anthropic),
#[cfg(feature = "ollama")]
"ollama" => Some(ProviderKind::Ollama),
#[cfg(feature = "mlx")]
"mlx" => Some(ProviderKind::MLX),
#[cfg(feature = "lmstudio")]
"lmstudio" => Some(ProviderKind::LMStudio),
_ => None,
}
}
pub fn as_str(&self) -> &'static str {
match self {
ProviderKind::Mock => "mock",
#[cfg(feature = "openai")]
ProviderKind::OpenAI => "openai",
#[cfg(feature = "anthropic")]
ProviderKind::Anthropic => "anthropic",
#[cfg(feature = "ollama")]
ProviderKind::Ollama => "ollama",
#[cfg(feature = "mlx")]
ProviderKind::MLX => "mlx",
#[cfg(feature = "lmstudio")]
ProviderKind::LMStudio => "lmstudio",
}
}
}
/// Common interface implemented by every model backend.
///
/// Implementors must be `Send + Sync`. `generate`, `stream`, `metadata` and
/// `kind` are required; the `*_with_attachments` methods have default
/// implementations that ignore the attachments.
#[async_trait]
pub trait ModelProvider: Send + Sync {
/// Produces a complete response for `prompt` using `config`.
///
/// # Errors
///
/// Returns whatever error the implementation reports.
async fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<ModelResponse>;
/// Like [`ModelProvider::generate`], but also receives `attachments`.
///
/// The default implementation discards `attachments` and delegates to
/// [`ModelProvider::generate`].
async fn generate_with_attachments(
&self,
prompt: &str,
attachments: &[ImageAttachment],
config: &GenerationConfig,
) -> Result<ModelResponse> {
// Explicitly mark the parameter as intentionally unused.
let _ = attachments;
self.generate(prompt, config).await
}
/// Starts a streamed response for `prompt` using `config`.
///
/// On success returns a pinned, boxed, `Send` stream whose items are
/// `Result<String>` (NOTE(review): presumably incremental text chunks —
/// confirm against implementations).
async fn stream(
&self,
prompt: &str,
config: &GenerationConfig,
) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>>;
/// Like [`ModelProvider::stream`], but also receives `attachments`.
///
/// The default implementation discards `attachments` and delegates to
/// [`ModelProvider::stream`].
async fn stream_with_attachments(
&self,
prompt: &str,
attachments: &[ImageAttachment],
config: &GenerationConfig,
) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
// Explicitly mark the parameter as intentionally unused.
let _ = attachments;
self.stream(prompt, config).await
}
/// Returns descriptive information about this provider.
fn metadata(&self) -> ProviderMetadata;
/// Returns which [`ProviderKind`] this provider is.
fn kind(&self) -> ProviderKind;
}
// Unit tests for `ProviderKind`, `GenerationConfig` and `parse_thinking_tokens`.
#[cfg(test)]
mod tests {
use super::*;
// Parsing is case-insensitive and rejects unknown names.
#[test]
fn test_provider_kind_from_str() {
assert_eq!(ProviderKind::from_str("mock"), Some(ProviderKind::Mock));
assert_eq!(ProviderKind::from_str("Mock"), Some(ProviderKind::Mock));
assert_eq!(ProviderKind::from_str("MOCK"), Some(ProviderKind::Mock));
assert_eq!(ProviderKind::from_str("invalid"), None);
}
// `as_str` yields the lowercase name.
#[test]
fn test_provider_kind_as_str() {
assert_eq!(ProviderKind::Mock.as_str(), "mock");
}
// Pins the default values of the fields that `Default` populates.
#[test]
fn test_generation_config_default() {
let config = GenerationConfig::default();
assert_eq!(config.temperature, Some(0.7));
assert_eq!(config.max_tokens, Some(2048));
assert_eq!(config.top_p, Some(1.0));
}
// JSON round trip keeps `temperature` and `max_tokens`.
#[test]
fn test_generation_config_serialization() {
let config = GenerationConfig {
temperature: Some(0.9),
max_tokens: Some(1024),
stop_sequences: Some(vec!["STOP".to_string()]),
top_p: Some(0.95),
frequency_penalty: None,
presence_penalty: None,
};
let json = serde_json::to_string(&config).unwrap();
let deserialized: GenerationConfig = serde_json::from_str(&json).unwrap();
assert_eq!(config.temperature, deserialized.temperature);
assert_eq!(config.max_tokens, deserialized.max_tokens);
}
// A well-formed think block is split into reasoning and content.
#[test]
fn test_parse_thinking_tokens_with_tags() {
let response = "<think>Let me consider this carefully...</think>Here's my final answer.";
let (reasoning, content) = parse_thinking_tokens(response);
assert_eq!(
reasoning,
Some("Let me consider this carefully...".to_string())
);
assert_eq!(content, "Here's my final answer.");
}
// Without tags there is no reasoning and the content is the input itself.
#[test]
fn test_parse_thinking_tokens_without_tags() {
let response = "This is a normal response without thinking tags.";
let (reasoning, content) = parse_thinking_tokens(response);
assert_eq!(reasoning, None);
assert_eq!(content, "This is a normal response without thinking tags.");
}
// The think block may span several lines.
#[test]
fn test_parse_thinking_tokens_multiline() {
let response = "<think>\nFirst, I need to analyze the problem.\nThen I'll formulate a solution.\n</think>\n\nHere's the answer: 42";
let (reasoning, content) = parse_thinking_tokens(response);
assert!(reasoning.is_some());
let reasoning_text = reasoning.unwrap();
assert!(reasoning_text.contains("analyze the problem"));
assert!(reasoning_text.contains("formulate a solution"));
assert_eq!(content, "Here's the answer: 42");
}
// An empty think block yields `Some("")`, not `None`.
#[test]
fn test_parse_thinking_tokens_empty_think() {
let response = "<think></think>Content after empty think.";
let (reasoning, content) = parse_thinking_tokens(response);
assert_eq!(reasoning, Some("".to_string()));
assert_eq!(content, "Content after empty think.");
}
// Both the reasoning and the content are trimmed.
#[test]
fn test_parse_thinking_tokens_whitespace_handling() {
let response = "<think> \n Some reasoning \n </think> \n Final answer";
let (reasoning, content) = parse_thinking_tokens(response);
assert_eq!(reasoning, Some("Some reasoning".to_string()));
assert_eq!(content, "Final answer");
}
// An unclosed `<think>` leaves the input untouched.
#[test]
fn test_parse_thinking_tokens_incomplete_tag() {
let response = "<think>Incomplete thinking...";
let (reasoning, content) = parse_thinking_tokens(response);
assert_eq!(reasoning, None);
assert_eq!(content, "<think>Incomplete thinking...");
}
}