pub mod anthropic;
pub mod openai;
use std::collections::HashMap;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::{Result, tools::Schema};
/// Author of a [`Message`] within a conversation.
///
/// Serialised in lowercase (`"system"`, `"user"`, `"assistant"`, `"tool"`).
/// `Eq` and `Hash` are derived so roles can be compared exhaustively and used
/// as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// The system role.
    System,
    /// The user role. [`Message::tool_results`] also uses this role.
    User,
    /// The assistant (model) role.
    Assistant,
    /// The tool role. No constructor in this module produces it
    /// (`Message::tool_results` uses `User`).
    Tool,
}
/// A single tool invocation requested by the model.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCallRequest {
    /// Identifier of this call. Presumably echoed back in
    /// [`ToolCallResult::tool_call_id`] (TODO confirm against the adapters).
    pub id: String,
    /// Name of the tool the model wants to run.
    pub tool_name: String,
    /// Tool arguments as arbitrary JSON.
    pub input: serde_json::Value,
}
/// Outcome of running one tool call, fed back to the model.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCallResult {
    /// Identifier of the call this result answers. Presumably matches
    /// [`ToolCallRequest::id`] (TODO confirm).
    pub tool_call_id: String,
    /// Name of the tool that produced this result.
    pub tool_name: String,
    /// Tool output as arbitrary JSON.
    pub output: serde_json::Value,
    /// `true` when `output` describes a failure rather than a successful result.
    pub is_error: bool,
}
/// One turn of a conversation: who said it and what was said.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    /// Author of this turn.
    pub role: Role,
    /// Payload of this turn (text, tool calls, or tool results).
    pub content: MessageContent,
}
/// Payload carried by a [`Message`].
///
/// Serialised as an internally tagged object: the variant name in snake_case
/// is stored under `"type"` next to the variant's fields, e.g.
/// `{"type":"text","text":"..."}` or `{"type":"tool_use","calls":[...]}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum MessageContent {
    /// Plain text.
    Text { text: String },
    /// One or more tool invocations requested by the model.
    ToolUse { calls: Vec<ToolCallRequest> },
    /// Results of previously requested tool invocations.
    ToolResult { results: Vec<ToolCallResult> },
}
impl Message {
pub fn user(text: impl Into<String>) -> Self {
Self {
role: Role::User,
content: MessageContent::Text { text: text.into() },
}
}
pub fn system(text: impl Into<String>) -> Self {
Self {
role: Role::System,
content: MessageContent::Text { text: text.into() },
}
}
pub fn assistant(text: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: MessageContent::Text { text: text.into() },
}
}
pub fn tool_results(results: Vec<ToolCallResult>) -> Self {
Self {
role: Role::User,
content: MessageContent::ToolResult { results },
}
}
}
/// Description of a tool offered to the model, as carried in [`Request::tools`].
///
/// `parameters` is a `HashMap`, whose iteration order is unspecified. Any
/// serialisation that walks it may emit parameters in a different order from
/// run to run.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Name the model uses to invoke the tool.
    pub name: String,
    /// Human-readable explanation of what the tool does.
    pub description: String,
    /// Parameter definitions keyed by parameter name.
    pub parameters: HashMap<String, ParameterDefinition>,
}
/// A single named input accepted by a [`ToolDefinition`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ParameterDefinition {
    /// Type name copied from the tool schema (e.g. `"string"`).
    pub kind: String,
    /// Human-readable explanation of the parameter.
    pub description: String,
    /// Whether the caller must supply this parameter.
    pub required: bool,
}
impl ToolDefinition {
pub fn from_schema(name: &str, schema: &Schema) -> Self {
Self {
name: name.to_string(),
description: schema.description.clone(),
parameters: schema
.parameters
.iter()
.map(|(k, v)| {
(
k.clone(),
ParameterDefinition {
kind: v.kind.clone(),
description: v.description.clone(),
required: v.required,
},
)
})
.collect(),
}
}
}
/// Input to [`Adapter::complete`].
#[derive(Clone, Debug)]
pub struct Request {
    /// Conversation history to send.
    pub messages: Vec<Message>,
    /// Tools the model may call.
    pub tools: Vec<ToolDefinition>,
    /// System prompt, kept separate from `messages`.
    pub system: String,
    /// Presumably the cap on generated tokens (TODO confirm how each adapter maps it).
    pub max_tokens: u32,
    /// Optional model name. Presumably `None` means "use the adapter's own
    /// model" (TODO confirm).
    pub model: Option<String>,
}
/// Output of [`Adapter::complete`].
#[derive(Clone, Debug)]
pub struct Response {
    /// What the model produced: text or tool calls.
    pub content: ResponseContent,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    /// Token accounting for this call.
    pub usage: TokenUsage,
}
/// Body of a [`Response`].
///
/// NOTE(review): a response is either text or tool calls, never both. If a
/// provider returns both in one reply, the adapter must drop one of them.
/// Confirm this is intended.
#[derive(Clone, Debug)]
pub enum ResponseContent {
    /// Plain text reply.
    Text(String),
    /// Tool invocations requested by the model.
    ToolCalls(Vec<ToolCallRequest>),
}
/// Why the model stopped generating.
///
/// Variant meanings follow their names. The mapping from each provider's raw
/// stop reason is not visible in this module. `Eq` is derived because every
/// payload (`String`) is `Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    /// Normal stop.
    Stop,
    /// Stopped to request tool calls.
    ToolUse,
    /// Stopped at the token limit.
    MaxTokens,
    /// Any other reason. Presumably the provider's raw stop reason (TODO confirm).
    Other(String),
}
/// Token counts reported for a single completion.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenUsage {
    /// Tokens counted on the input (prompt) side.
    pub input_tokens: u32,
    /// Tokens counted on the output (generated) side.
    pub output_tokens: u32,
}

impl TokenUsage {
    /// Returns `input_tokens + output_tokens`.
    ///
    /// Saturates at `u32::MAX`. Plain `+` would panic on overflow in debug
    /// builds and silently wrap in release builds.
    pub fn total(&self) -> u32 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}
/// A chat-completion backend that turns a [`Request`] into a [`Response`].
///
/// Implementations must be `Send + 'static`. Presumably the `anthropic` and
/// `openai` submodules declared in this module provide them (TODO confirm).
#[async_trait]
pub trait Adapter: Send + 'static {
    /// Name of the provider behind this adapter.
    fn provider(&self) -> &str;

    /// Name of the model this adapter targets.
    fn model(&self) -> &str;

    /// Sends `req` to the backend and returns the parsed response, or the
    /// crate's error type on failure.
    async fn complete(&self, req: Request) -> Result<Response>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of a `MessageContent::Text` message, panicking on any other variant.
    fn text_of(msg: &Message) -> &str {
        match &msg.content {
            MessageContent::Text { text } => text.as_str(),
            other => panic!("expected text content, got {:?}", other),
        }
    }

    #[test]
    fn test_message_user() {
        let msg = Message::user("hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(text_of(&msg), "hello");
    }

    #[test]
    fn test_message_system() {
        let msg = Message::system("you are a researcher");
        assert_eq!(msg.role, Role::System);
        assert_eq!(text_of(&msg), "you are a researcher");
    }

    #[test]
    fn test_message_assistant() {
        let msg = Message::assistant("I found the answer");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(text_of(&msg), "I found the answer");
    }

    #[test]
    fn test_message_tool_results() {
        let result = ToolCallResult {
            tool_call_id: "call_1".to_string(),
            tool_name: "web_search".to_string(),
            output: serde_json::json!({ "hits": 3 }),
            is_error: false,
        };
        let msg = Message::tool_results(vec![result]);
        // Tool results travel in a user-role message, not `Role::Tool`.
        assert_eq!(msg.role, Role::User);
        match msg.content {
            MessageContent::ToolResult { results } => {
                assert_eq!(results.len(), 1);
                assert_eq!(results[0].tool_call_id, "call_1");
                assert_eq!(results[0].tool_name, "web_search");
                assert!(!results[0].is_error);
            }
            other => panic!("expected tool result content, got {:?}", other),
        }
    }

    #[test]
    fn test_token_usage_total() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
        };
        assert_eq!(usage.total(), 150);
    }

    #[test]
    fn test_tool_definition_from_schema() {
        use crate::tools::{Parameter, Schema};
        use std::collections::HashMap;
        let schema = Schema {
            description: "Search the web".to_string(),
            parameters: HashMap::from([(
                "query".to_string(),
                Parameter {
                    kind: "string".to_string(),
                    description: "The search query".to_string(),
                    required: true,
                },
            )]),
        };
        let def = ToolDefinition::from_schema("web_search", &schema);
        assert_eq!(def.name, "web_search");
        assert_eq!(def.description, "Search the web");
        assert!(def.parameters.contains_key("query"));
        assert!(def.parameters["query"].required);
        assert_eq!(def.parameters["query"].kind, "string");
        assert_eq!(def.parameters["query"].description, "The search query");
    }

    #[test]
    fn test_role_serialises_lowercase() {
        let cases = [
            (Role::System, "\"system\""),
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::Tool, "\"tool\""),
        ];
        for (role, expected) in cases {
            assert_eq!(serde_json::to_string(&role).unwrap(), expected);
        }
    }

    #[test]
    fn test_message_content_is_internally_tagged() {
        let text = serde_json::to_value(MessageContent::Text {
            text: "hi".to_string(),
        })
        .unwrap();
        assert_eq!(text, serde_json::json!({ "type": "text", "text": "hi" }));

        let tool_use = serde_json::to_value(MessageContent::ToolUse { calls: vec![] }).unwrap();
        assert_eq!(tool_use, serde_json::json!({ "type": "tool_use", "calls": [] }));

        let tool_result =
            serde_json::to_value(MessageContent::ToolResult { results: vec![] }).unwrap();
        assert_eq!(
            tool_result,
            serde_json::json!({ "type": "tool_result", "results": [] })
        );
    }

    #[test]
    fn test_message_round_trips_through_json() {
        let json = serde_json::to_string(&Message::assistant("done")).unwrap();
        let back: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(back.role, Role::Assistant);
        assert_eq!(text_of(&back), "done");
    }

    #[test]
    fn test_finish_reason_equality() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::ToolUse);
    }
}