use async_trait::async_trait;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::{borrow::Borrow, pin::Pin};
use super::ChatError;
/// One piece of message content: either text or an image.
///
/// Serialized with serde's default (externally tagged) enum representation,
/// since no `#[serde(...)]` container attribute is applied.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ContentPart {
    /// Plain text.
    Text(String),
    /// An image carried inline as bytes.
    Image {
        /// MIME type string associated with `data`; not validated by this type.
        mime_type: String,
        /// Image bytes.
        /// NOTE(review): presumably the undecoded file contents — confirm with producers.
        data: Vec<u8>,
    },
}
impl ContentPart {
    /// Consumes the part and hands back its text, if it has any.
    ///
    /// Returns `Some` with the owned string for [`ContentPart::Text`] and
    /// `None` for every other variant.
    pub fn into_text(self) -> Option<String> {
        if let ContentPart::Text(text) = self {
            Some(text)
        } else {
            None
        }
    }
}
/// Lets a slice of [`ContentPart`] be joined into a `String` (slice `join`
/// requires `Borrow<str>` on the element type).
///
/// NOTE(review): the `Borrow` documentation asks that `Eq`/`Ord`/`Hash` agree
/// between the owned and borrowed forms. Every image borrows as the same
/// placeholder here, so that does not hold for `Image` values — avoid using
/// `ContentPart` as a map/set key that is looked up by `str`.
impl Borrow<str> for ContentPart {
    /// Returns the text of a `Text` part, or the literal `"[Image]"`
    /// placeholder for an `Image` part.
    fn borrow(&self) -> &str {
        match self {
            ContentPart::Image { .. } => "[Image]",
            ContentPart::Text(text) => text.as_str(),
        }
    }
}
/// The outcome of one tool invocation, as carried by [`Message::Tool`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolResult {
    /// Identifier of the call this result answers.
    /// NOTE(review): presumably matches `ToolCallRequest::id` — confirm with callers.
    pub call_id: String,
    /// Name of the tool.
    pub name: String,
    /// Arbitrary JSON payload produced by the tool.
    pub content: serde_json::Value,
}
/// A single entry in a chat conversation.
///
/// Derives `Eq` in addition to `PartialEq`: every payload type
/// (`ContentPart`, `ToolCallRequest`, `ToolResult`) is already `Eq`, so the
/// enum can offer full equality like its sibling types do.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    /// System message content.
    System(Vec<ContentPart>),
    /// User message content.
    User(Vec<ContentPart>),
    /// Assistant message: content plus any tool calls it requested.
    Assistant {
        /// Content parts of the assistant message.
        content: Vec<ContentPart>,
        /// Tool calls attached to the message; deserializes to an empty list
        /// when the field is absent from the input.
        #[serde(default)]
        tool_calls: Vec<ToolCallRequest>,
    },
    /// Results of tool invocations.
    Tool(Vec<ToolResult>),
}
impl Message {
    /// Builds a [`Message::System`] containing a single text part.
    pub fn system(text: impl Into<String>) -> Self {
        let part = ContentPart::Text(text.into());
        Self::System(vec![part])
    }

    /// Builds a [`Message::User`] containing a single text part.
    pub fn user(text: impl Into<String>) -> Self {
        let part = ContentPart::Text(text.into());
        Self::User(vec![part])
    }

    /// Builds a [`Message::Assistant`] with a single text part and no tool calls.
    pub fn assistant(text: impl Into<String>) -> Self {
        Self::assistant_response(vec![ContentPart::Text(text.into())], Vec::new())
    }

    /// Builds a [`Message::Assistant`] from already-assembled content and tool calls.
    pub fn assistant_response(content: Vec<ContentPart>, tool_calls: Vec<ToolCallRequest>) -> Self {
        Self::Assistant {
            content,
            tool_calls,
        }
    }

    /// Wraps tool results in a [`Message::Tool`].
    pub fn tool(results: Vec<ToolResult>) -> Self {
        Self::Tool(results)
    }

    /// Concatenates the message's content parts into one string, with no separator.
    ///
    /// Each part contributes whatever its `Borrow<str>` implementation yields
    /// (so image parts contribute that impl's placeholder text). A
    /// [`Message::Tool`] has no content parts and yields an empty string.
    pub fn text_content(&self) -> String {
        let parts: &[ContentPart] = match self {
            Self::Tool(..) => return String::new(),
            Self::System(parts) | Self::User(parts) => parts.as_slice(),
            Self::Assistant { content, .. } => content.as_slice(),
        };
        parts.join("")
    }
}
/// Describes the parameters a tool accepts.
///
/// NOTE(review): the field names (`type`, `properties`, `required`) look like
/// a JSON-Schema object description — confirm against the providers that
/// consume it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolParameterSchema {
    /// Schema type; (de)serialized under the key `type`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Per-property JSON descriptions; an absent field deserializes to an empty map.
    #[serde(default)]
    pub properties: serde_json::Map<String, JsonValue>,
    /// Names listed as required; an absent field deserializes to an empty list.
    #[serde(default)]
    pub required: Vec<String>,
}
/// A tool that can be offered through [`ChatOptions::tools`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Free-form description of the tool.
    pub description: String,
    /// Schema of the parameters the tool takes.
    pub parameters: ToolParameterSchema,
}
/// Tool-selection setting carried in [`ChatOptions::tool_choice`].
///
/// Variant names are (de)serialized in `snake_case`.
/// NOTE(review): the exact effect of each variant is decided by the provider
/// implementation, which is not visible here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    /// Serialized as `auto`.
    Auto,
    /// Serialized as `none`.
    None,
    /// Serialized as `required`.
    Required,
    /// Selects one tool by name.
    Tool {
        /// Name of the selected tool.
        name: String,
    },
}
/// A request to invoke a tool, as found on assistant messages, responses and
/// stream chunks.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCallRequest {
    /// Identifier of this call.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Call arguments as an arbitrary JSON value.
    pub arguments: JsonValue,
}
/// Optional settings passed alongside the messages to [`ChatApi`] calls.
///
/// Every field is an `Option`; the derived `Default` leaves all of them `None`.
/// NOTE(review): how a `None` is interpreted (e.g. a provider default) is up
/// to the `ChatApi` implementation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChatOptions {
    /// Model identifier to use, if set.
    pub model_id: Option<String>,
    /// Temperature setting, if set.
    pub temperature: Option<f32>,
    /// Token limit, if set.
    pub max_tokens: Option<u32>,
    /// `top_p` setting, if set.
    pub top_p: Option<f32>,
    /// Stop sequences, if set.
    pub stop_sequences: Option<Vec<String>>,
    /// Tool definitions to offer, if any.
    #[serde(default)]
    pub tools: Option<Vec<ToolDefinition>>,
    /// Tool-selection setting, if any.
    #[serde(default)]
    pub tool_choice: Option<ToolChoice>,
}
/// Token counts attached to a response or stream.
///
/// Each count is optional; the derived `Default` sets all of them to `None`.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct UsageInfo {
    /// Prompt token count, when known.
    pub prompt_tokens: Option<u32>,
    /// Completion token count, when known.
    pub completion_tokens: Option<u32>,
    /// Total token count, when known.
    pub total_tokens: Option<u32>,
}
/// Why a generation ended.
///
/// Variant names are (de)serialized in `snake_case`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Serialized as `stop`.
    Stop,
    /// Serialized as `length`.
    Length,
    /// Serialized as `content_filter`.
    ContentFilter,
    /// Serialized as `tool_calls`.
    ToolCalls,
    /// Serialized as `unspecified`.
    Unspecified,
    /// Any other reason, carried as a string.
    Other(String),
}
/// The result returned by [`ChatApi::generate`].
///
/// Derives `Eq` in addition to `PartialEq`: every field type
/// (`ContentPart`, `ToolCallRequest`, `UsageInfo`, `FinishReason`, `String`)
/// is already `Eq`, matching the other response-side types in this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Content parts of the response.
    pub content: Vec<ContentPart>,
    /// Tool calls included in the response; deserializes to an empty list
    /// when the field is absent from the input.
    #[serde(default)]
    pub tool_calls: Vec<ToolCallRequest>,
    /// Token usage, when available.
    pub usage: Option<UsageInfo>,
    /// Why generation ended, when available.
    pub finish_reason: Option<FinishReason>,
    /// Model identifier, when available.
    pub model_id: Option<String>,
}
/// Description of one model, as returned by [`ChatApi::list_models`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier.
    pub id: String,
    /// Optional description of the model.
    pub description: Option<String>,
    /// Context window size, when known.
    /// NOTE(review): presumably measured in tokens — confirm.
    pub context_window: Option<u32>,
    /// Maximum number of output tokens, when known.
    pub max_output_tokens: Option<u32>,
}
/// One item yielded by a [`ChatStream`].
///
/// Derives `Eq` in addition to `PartialEq`: every payload type
/// (`String`, `ToolCallRequest`, `UsageInfo`, `FinishReason`, `JsonValue`)
/// is already used under an `Eq` derive elsewhere in this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum StreamChunk {
    /// A piece of text.
    Text(String),
    /// A tool call.
    ToolCall(ToolCallRequest),
    /// Token usage information.
    Usage(UsageInfo),
    /// End-of-stream marker.
    StreamEnd {
        /// Why generation ended.
        finish_reason: FinishReason,
        /// Token usage, when available.
        usage: Option<UsageInfo>,
    },
    /// An error reported as a stream item.
    StreamError {
        /// Error message.
        message: String,
        /// Optional error code.
        code: Option<String>,
    },
    /// Data that does not fit the other variants.
    ProviderSpecific {
        /// Label identifying what `data` holds.
        kind: String,
        /// Arbitrary JSON payload.
        data: JsonValue,
    },
}
/// Pinned, boxed, `Send` stream whose items are `Result<StreamChunk, ChatError>`.
///
/// This is the success type of `ChatApi::generate_stream`.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, ChatError>> + Send>>;
/// Asynchronous interface to a chat backend.
///
/// Implementors must be `Send + Sync`. Every method reports failure as a
/// [`ChatError`].
#[async_trait]
pub trait ChatApi: Send + Sync {
    /// Returns the models this backend exposes.
    async fn list_models(&self) -> Result<Vec<ModelInfo>, ChatError>;

    /// Produces a complete [`ChatResponse`] for `messages` using `options`.
    async fn generate(
        &self,
        messages: &[Message],
        options: &ChatOptions,
    ) -> Result<ChatResponse, ChatError>;

    /// Produces a [`ChatStream`] of [`StreamChunk`] items for `messages` using `options`.
    async fn generate_stream(
        &self,
        messages: &[Message],
        options: &ChatOptions,
    ) -> Result<ChatStream, ChatError>;
}