use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// A single conversation message: who sent it (`role`) and what it says (`content`).
///
/// Derives serde `Serialize`/`Deserialize`, so it (de)serializes with the
/// field names `role` and `content`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Sender role. The constructors on this type produce `"system"`, `"user"`
    /// or `"assistant"`, but the field itself is a free-form `String`.
    pub role: String,
    /// Message text.
    pub content: String,
}
impl Message {
    /// Creates a message with role `"system"` carrying `content`.
    pub fn system(content: String) -> Self {
        Self::tagged("system", content)
    }

    /// Creates a message with role `"user"` carrying `content`.
    pub fn user(content: String) -> Self {
        Self::tagged("user", content)
    }

    /// Creates a message with role `"assistant"` carrying `content`.
    pub fn assistant(content: String) -> Self {
        Self::tagged("assistant", content)
    }

    /// Shared constructor: pairs a role label with the given content.
    fn tagged(role: &str, content: String) -> Self {
        Self {
            content,
            role: role.to_owned(),
        }
    }
}
/// Bounded, in-memory conversation history for a chat session.
///
/// `messages` holds the rolling history (oldest first) and is capped at
/// `max_messages`. Whenever an added message pushes the length past the cap,
/// the oldest entries are dropped. `system_message`, when set, lives outside
/// that window. It is never trimmed or counted against the cap, and
/// [`SessionManager::get_messages_for_api`] prepends it to the history.
#[derive(Debug, Clone)]
pub struct SessionManager {
    /// Conversation history, oldest first. Does not include `system_message`.
    pub messages: Vec<Message>,
    /// Maximum number of entries retained in `messages`.
    pub max_messages: usize,
    /// Optional system prompt, sent first and exempt from trimming.
    pub system_message: Option<Message>,
}

impl Default for SessionManager {
    /// Empty session with no system message and a cap of 100 messages.
    fn default() -> Self {
        // Delegate to `new` so the two constructors cannot drift apart.
        Self::new(Self::DEFAULT_MAX_MESSAGES)
    }
}

impl SessionManager {
    /// History cap used by [`Default`].
    const DEFAULT_MAX_MESSAGES: usize = 100;

    /// Creates an empty session that retains at most `max_messages` history entries.
    ///
    /// With `max_messages == 0` every added message is dropped immediately.
    pub fn new(max_messages: usize) -> Self {
        Self {
            messages: Vec::new(),
            max_messages,
            system_message: None,
        }
    }

    /// Builder-style setter: sets (or replaces) the system message and returns `self`.
    pub fn with_system_message(mut self, content: String) -> Self {
        self.system_message = Some(Message::system(content));
        self
    }

    /// Appends a `"user"` message, trimming the oldest history if over the cap.
    pub fn add_user_message(&mut self, content: String) {
        self.add_message(Message::user(content));
    }

    /// Appends an `"assistant"` message, trimming the oldest history if over the cap.
    pub fn add_assistant_message(&mut self, content: String) {
        self.add_message(Message::assistant(content));
    }

    /// Appends `message` to the history, then enforces `max_messages`.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
        self.trim_if_needed();
    }

    /// Replaces the entire history with a single `"system"`-role message of the
    /// form `"Previous conversation summary: {summary}"`.
    ///
    /// `system_message` is left untouched. The summary goes through the normal
    /// add path, so it is subject to the `max_messages` cap. With a cap of 0 it
    /// is dropped immediately, and it can later be trimmed like any other entry.
    pub fn replace_with_summary(&mut self, summary: String) {
        self.messages.clear();
        self.add_message(Message::system(format!(
            "Previous conversation summary: {}",
            summary
        )));
    }

    /// Returns the full message list to send to a backend: the system message
    /// (if any) followed by the retained history, all cloned.
    pub fn get_messages_for_api(&self) -> Vec<Message> {
        // Size the buffer once so neither the push nor the extend reallocates.
        let mut api_messages = Vec::with_capacity(self.messages.len() + 1);
        if let Some(sys_message) = &self.system_message {
            api_messages.push(sys_message.clone());
        }
        api_messages.extend_from_slice(&self.messages);
        api_messages
    }

    /// Removes all history entries. The system message is kept.
    pub fn clear(&mut self) {
        self.messages.clear();
    }

    /// Number of history entries currently retained (excludes the system message).
    pub fn message_count(&self) -> usize {
        self.messages.len()
    }

    /// Drops the oldest history entries until at most `max_messages` remain.
    fn trim_if_needed(&mut self) {
        let excess = self.messages.len().saturating_sub(self.max_messages);
        if excess > 0 {
            self.messages.drain(..excess);
        }
    }
}
/// Description of a tool the model may call, supplied via [`CompletionOptions::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameter specification as arbitrary JSON.
    // NOTE(review): presumably a JSON Schema object describing the tool's
    // arguments — confirm against what each backend expects.
    pub parameters: serde_json::Value,
}
/// A tool invocation requested by the model, as returned from
/// `complete_with_tools`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call identifier, or `None` when absent.
    // NOTE(review): presumably `None` for backends that don't assign call ids — confirm.
    pub id: Option<String>,
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments as arbitrary JSON.
    pub arguments: serde_json::Value,
}
/// Output of an executed tool call, passed back through the `tool_results`
/// argument of `complete_with_tools`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Identifier of the tool call this result answers.
    // NOTE(review): presumably matches `ToolCall::id` — confirm how backends
    // handle calls whose id was `None`.
    pub tool_call_id: String,
    /// Tool output as text.
    pub output: String,
}
/// Per-request generation settings passed to the completion methods.
// NOTE(review): `None` on the optional fields presumably means "use the
// backend's default" — confirm per backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionOptions {
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Top-p sampling value.
    pub top_p: Option<f32>,
    /// Upper bound on the number of generated tokens.
    pub max_tokens: Option<u32>,
    /// Tools offered to the model, if any.
    pub tools: Option<Vec<ToolDefinition>>,
    /// Optional JSON schema, held as a raw string.
    // NOTE(review): presumably used to constrain the response format — confirm.
    pub json_schema: Option<String>,
    /// Tool-use requirement flag.
    // NOTE(review): presumably asks the backend to force a tool call — confirm
    // how each backend interprets it.
    pub require_tool_use: bool,
}
impl Default for CompletionOptions {
    /// Defaults: temperature `0.7`, top_p `0.9`, at most `2048` tokens,
    /// no tools, no JSON schema, and tool use not required.
    fn default() -> Self {
        let temperature = Some(0.7);
        let top_p = Some(0.9);
        let max_tokens = Some(2048);
        Self {
            temperature,
            top_p,
            max_tokens,
            require_tool_use: false,
            json_schema: None,
            tools: None,
        }
    }
}
/// Common interface for completion backends.
///
/// Requires `Send + Sync`, so implementors can be shared between threads.
/// The async methods are made object-safe-compatible via `async_trait`.
#[async_trait::async_trait]
pub trait ApiClient: Send + Sync {
    /// Requests a plain text completion for `messages` using `options`.
    // NOTE(review): `dead_code` is allowed here presumably because no caller
    // currently uses this method — confirm before removing the attribute.
    #[allow(dead_code)]
    async fn complete(&self, messages: Vec<Message>, options: CompletionOptions) -> Result<String>;
    /// Requests a completion during which the model may call tools.
    ///
    /// `tool_results` carries outputs of previously executed tool calls, if any.
    /// Returns the text reply together with an optional list of tool calls the
    /// model requested.
    async fn complete_with_tools(
        &self,
        messages: Vec<Message>,
        options: CompletionOptions,
        tool_results: Option<Vec<ToolResult>>,
    ) -> Result<(String, Option<Vec<ToolCall>>)>;
}
/// Closed set of supported completion backends, dispatched by `match`.
///
/// Each client is held in an `Arc`, so cloning the enum only bumps a
/// reference count and the clones share the same underlying client.
#[derive(Clone)]
pub enum ApiClientEnum {
    /// Anthropic backend.
    Anthropic(Arc<crate::apis::anthropic::AnthropicClient>),
    /// OpenAI backend.
    OpenAi(Arc<crate::apis::openai::OpenAIClient>),
    /// Ollama backend.
    Ollama(Arc<crate::apis::ollama::OllamaClient>),
}
impl ApiClientEnum {
#[allow(dead_code)]
pub async fn complete(
&self,
messages: Vec<Message>,
options: CompletionOptions,
) -> Result<String> {
match self {
Self::Anthropic(client) => client.complete(messages, options).await,
Self::OpenAi(client) => client.complete(messages, options).await,
Self::Ollama(client) => client.complete(messages, options).await,
}
}
pub async fn complete_with_tools(
&self,
messages: Vec<Message>,
options: CompletionOptions,
tool_results: Option<Vec<ToolResult>>,
) -> Result<(String, Option<Vec<ToolCall>>)> {
match self {
Self::Anthropic(client) => {
client
.complete_with_tools(messages, options, tool_results)
.await
}
Self::OpenAi(client) => {
client
.complete_with_tools(messages, options, tool_results)
.await
}
Self::Ollama(client) => {
client
.complete_with_tools(messages, options, tool_results)
.await
}
}
}
}
/// Type alias for [`ApiClientEnum`].
///
/// NOTE(review): despite the `Dyn` prefix, this is the concrete enum, which
/// dispatches with a `match` over a fixed set of backends. It is not a
/// `dyn ApiClient` trait object. Confirm whether the name predates the enum.
pub type DynApiClient = ApiClientEnum;