use crate::errors::Result;
use crate::state::Message;
use async_trait::async_trait;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
/// Backend module compiled only when the `openai` Cargo feature is enabled.
#[cfg(feature = "openai")]
pub mod openai;
/// Backend module compiled only when the `openrouter` Cargo feature is enabled.
#[cfg(feature = "openrouter")]
pub mod openrouter;
/// Backend module compiled only when the `anthropic` Cargo feature is enabled.
#[cfg(feature = "anthropic")]
pub mod anthropic;
/// Backend module compiled only when the `ollama` Cargo feature is enabled.
#[cfg(feature = "ollama")]
pub mod ollama;
/// Description of a tool that can be advertised to a chat model.
///
/// Uses serde's default derive, so the Rust field names (`name`,
/// `description`, `parameters`) are used as the serialized keys.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolInfo {
    /// Identifier of the tool.
    pub name: String,
    /// Human-readable explanation of what the tool does.
    pub description: String,
    /// Arbitrary JSON describing the tool's parameters.
    /// NOTE(review): presumably a JSON Schema object — confirm against the provider modules.
    pub parameters: serde_json::Value,
}

impl ToolInfo {
    /// Builds a `ToolInfo`.
    ///
    /// `name` and `description` accept anything convertible into `String`
    /// (e.g. `&str` or `String`). `parameters` is stored as given.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: serde_json::Value,
    ) -> Self {
        let name: String = name.into();
        let description: String = description.into();
        ToolInfo {
            name,
            description,
            parameters,
        }
    }
}
/// One piece of a streamed chat-model response.
#[derive(Clone, Debug)]
pub struct MessageChunk {
    /// Text carried by this chunk.
    pub content: String,
    /// Marks the chunk as final; the default `ChatModel::stream` sets this on its only chunk.
    pub is_final: bool,
    /// Reason generation stopped, when reported (the default `ChatModel::stream` uses `"stop"`).
    pub finish_reason: Option<String>,
}
/// Common interface implemented by chat-model backends.
///
/// Implementors must be `Send + Sync`. They must also provide
/// [`ChatModel::clone_box`] so that `Box<dyn ChatModel>` can be cloned.
#[async_trait]
pub trait ChatModel: Send + Sync {
    /// Sends the conversation in `messages` and returns the model's reply as one message.
    ///
    /// # Errors
    /// Returns whatever error the implementation produces.
    async fn invoke(&self, messages: &[Message]) -> Result<Message>;

    /// Returns the reply as a stream of [`MessageChunk`]s.
    ///
    /// The default implementation is not incremental. It awaits
    /// [`ChatModel::invoke`] and then yields the whole reply's `content` as a
    /// single chunk with `is_final = true` and `finish_reason = Some("stop")`.
    /// Backends that support native streaming can override it.
    ///
    /// # Errors
    /// Propagates any error from [`ChatModel::invoke`]; no stream is returned in that case.
    async fn stream(
        &self,
        messages: &[Message],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<MessageChunk>> + Send>>> {
        let reply = self.invoke(messages).await?;
        let final_chunk = MessageChunk {
            content: reply.content,
            is_final: true,
            finish_reason: Some(String::from("stop")),
        };
        // Only `final_chunk` is captured by the async block; the `Result`
        // wrapper is created when the single item is produced.
        let single = futures::stream::once(async move {
            let item: Result<MessageChunk> = Ok(final_chunk);
            item
        });
        Ok(Box::pin(single))
    }

    /// Identifier for this backend; defaults to `"unknown"` unless overridden.
    fn name(&self) -> &str {
        "unknown"
    }

    /// Returns a boxed copy of this model.
    ///
    /// Used by `impl Clone for Box<dyn ChatModel>`.
    fn clone_box(&self) -> Box<dyn ChatModel>;
}
impl Clone for Box<dyn ChatModel> {
fn clone(&self) -> Self {
self.clone_box()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Minimal `ChatModel` that always replies with a fixed assistant message.
    #[derive(Clone)]
    struct MockModel;

    #[async_trait]
    impl ChatModel for MockModel {
        async fn invoke(&self, _messages: &[Message]) -> Result<Message> {
            Ok(Message::assistant("Mock response"))
        }

        fn clone_box(&self) -> Box<dyn ChatModel> {
            Box::new(self.clone())
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    /// `invoke` and the overridden `name` are reachable on the concrete type.
    #[tokio::test]
    async fn test_mock_model() {
        let model = MockModel;
        let messages = vec![Message::user("Hello")];
        let response = model.invoke(&messages).await.unwrap();
        assert_eq!(response.content, "Mock response");
        assert_eq!(response.role, "assistant");
        assert_eq!(model.name(), "mock");
    }

    /// The default `stream` yields exactly one final chunk. That chunk carries
    /// the full `invoke` reply and a `"stop"` finish reason.
    #[tokio::test]
    async fn test_default_stream() {
        let model = MockModel;
        let messages = vec![Message::user("Hello")];
        let mut stream = model.stream(&messages).await.unwrap();
        let chunk = stream
            .next()
            .await
            .expect("default stream should yield one chunk")
            .unwrap();
        assert!(chunk.is_final);
        assert_eq!(chunk.content, "Mock response");
        assert_eq!(chunk.finish_reason.as_deref(), Some("stop"));
        assert!(
            stream.next().await.is_none(),
            "default stream should end after its single chunk"
        );
    }

    /// Cloning a `Box<dyn ChatModel>` goes through `clone_box` and yields a
    /// model that still works.
    #[tokio::test]
    async fn test_boxed_clone() {
        let original: Box<dyn ChatModel> = Box::new(MockModel);
        let cloned = original.clone();
        assert_eq!(cloned.name(), "mock");
        let messages = vec![Message::user("Hello")];
        let response = cloned.invoke(&messages).await.unwrap();
        assert_eq!(response.content, "Mock response");
    }
}