use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use crate::wasm::{WasmCompatSend, WasmCompatSync};
/// Identifies the author of a [`Message`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
/// The message comes from the user.
User,
/// The message comes from the assistant.
Assistant,
/// The message is a system-level message.
System,
}
/// One element of a [`Message`]'s content.
///
/// NOTE(review): variant descriptions are derived from the variant and field
/// names; how each variant maps onto a provider's wire format is not visible
/// in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ContentBlock {
/// Plain text.
Text(String),
/// Reasoning text together with a signature string.
Thinking {
/// The reasoning text.
thinking: String,
/// Signature accompanying the reasoning text (presumably issued by the
/// provider so the block can be verified later; confirm).
signature: String,
},
/// Thinking content carried as an opaque string (presumably redacted by
/// the provider; confirm).
RedactedThinking {
/// The opaque payload.
data: String,
},
/// A request to invoke a tool.
ToolUse {
/// Identifier of this tool invocation.
id: String,
/// Name of the tool to invoke.
name: String,
/// Arguments for the tool, as arbitrary JSON.
input: serde_json::Value,
},
/// The outcome of a tool invocation.
ToolResult {
/// Identifier of the invocation this result belongs to (presumably the
/// `id` of the matching `ToolUse`; confirm).
tool_use_id: String,
/// Items produced by the tool.
content: Vec<ContentItem>,
/// Whether the result represents an error.
is_error: bool,
},
/// An image.
Image {
/// Where the image data comes from.
source: ImageSource,
},
/// A document.
Document {
/// Where the document data comes from.
source: DocumentSource,
},
/// Compaction content (presumably a summary that replaces earlier context;
/// confirm).
Compaction {
/// The compaction text.
content: String,
},
}
/// A single item of tool-produced content, used by `ContentBlock::ToolResult`
/// and [`ToolOutput`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ContentItem {
/// Plain text.
Text(String),
/// An image.
Image {
/// Where the image data comes from.
source: ImageSource,
},
}
/// Describes where an image's data comes from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImageSource {
/// Image data embedded inline as a base64 string.
Base64 {
/// Media type of the image (presumably a MIME type such as `image/png`;
/// confirm).
media_type: String,
/// The base64-encoded image data.
data: String,
},
/// Image referenced by URL.
Url {
/// The image URL.
url: String,
},
}
/// Describes where a document's data comes from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DocumentSource {
/// A PDF embedded inline as a base64 string.
Base64Pdf {
/// The base64-encoded PDF data.
data: String,
},
/// A plain-text document embedded inline.
PlainText {
/// The document text.
data: String,
},
/// Document referenced by URL.
Url {
/// The document URL.
url: String,
},
}
/// A single conversation message: an author role plus an ordered list of
/// content blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Who authored the message.
pub role: Role,
/// The message content, in order.
pub content: Vec<ContentBlock>,
}
impl Message {
#[must_use]
pub fn user(text: impl Into<String>) -> Self {
Self {
role: Role::User,
content: vec![ContentBlock::Text(text.into())],
}
}
#[must_use]
pub fn assistant(text: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: vec![ContentBlock::Text(text.into())],
}
}
#[must_use]
pub fn system(text: impl Into<String>) -> Self {
Self {
role: Role::System,
content: vec![ContentBlock::Text(text.into())],
}
}
}
/// A system prompt, given either as one string or as a list of blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SystemPrompt {
/// The whole prompt as a single string.
Text(String),
/// The prompt split into blocks, each with optional cache settings.
Blocks(Vec<SystemBlock>),
}
/// One block of a [`SystemPrompt::Blocks`] prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemBlock {
/// The block's text.
pub text: String,
/// Optional cache settings for this block.
pub cache_control: Option<CacheControl>,
}
/// Cache settings that can be attached to a [`SystemBlock`] or a
/// [`ToolDefinition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheControl {
/// Requested time-to-live; `None` leaves it unspecified.
pub ttl: Option<CacheTtl>,
}
/// The time-to-live choices available for [`CacheControl`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CacheTtl {
/// A five-minute time-to-live.
FiveMinutes,
/// A one-hour time-to-live.
OneHour,
}
/// Controls tool selection for a [`CompletionRequest`].
///
/// NOTE(review): variant meanings below are inferred from their names; the
/// mapping to a provider request is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolChoice {
/// Presumably lets the model decide whether to call a tool.
Auto,
/// Presumably disallows tool calls.
None,
/// Presumably forces the model to call some tool.
Required,
/// Presumably forces a call to the named tool.
Specific {
/// Name of the tool to use.
name: String,
},
}
/// The output format requested for a completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResponseFormat {
/// Plain text output.
Text,
/// Output as a JSON object.
JsonObject,
/// Output described by a named JSON schema.
JsonSchema {
/// Name of the schema.
name: String,
/// The schema itself, as arbitrary JSON.
schema: serde_json::Value,
/// Strictness flag (presumably requests strict schema adherence;
/// confirm).
strict: bool,
},
}
/// Configuration of model thinking for a [`CompletionRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ThinkingConfig {
/// Thinking is enabled with a token budget.
Enabled {
/// Token budget for thinking.
budget_tokens: usize,
},
/// Thinking is disabled.
Disabled,
/// Adaptive thinking (presumably the model chooses how much to think;
/// confirm).
Adaptive,
}
/// The reasoning-effort level requested for a completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReasoningEffort {
/// No reasoning effort.
None,
/// Low reasoning effort.
Low,
/// Medium reasoning effort.
Medium,
/// High reasoning effort.
High,
}
/// Context-management settings for a [`CompletionRequest`].
///
/// The derived `Default` yields an empty edit list.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ContextManagement {
/// The context edits to apply.
pub edits: Vec<ContextEdit>,
}
/// A single context edit listed in [`ContextManagement`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ContextEdit {
/// Compact the context.
Compact {
/// Name of the compaction strategy. NOTE(review): accepted values are not
/// visible in this file.
strategy: String,
},
}
/// Token counts for one iteration, as listed in `TokenUsage::iterations`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct UsageIteration {
/// Number of input tokens.
pub input_tokens: usize,
/// Number of output tokens.
pub output_tokens: usize,
/// Number of tokens read from cache, when reported.
pub cache_read_tokens: Option<usize>,
/// Number of tokens written to cache, when reported.
pub cache_creation_tokens: Option<usize>,
}
/// A request to compute embeddings for a batch of input strings.
///
/// Unlike [`EmbeddingResponse`], this type does not derive `Serialize` or
/// `Deserialize`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EmbeddingRequest {
/// Name of the embedding model.
pub model: String,
/// The strings to embed.
pub input: Vec<String>,
/// Requested embedding dimensionality, if any.
pub dimensions: Option<usize>,
/// Additional key/value options as arbitrary JSON (presumably
/// provider-specific; confirm).
pub extra: HashMap<String, serde_json::Value>,
}
/// The result of an embedding request.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingResponse {
/// The embedding vectors (presumably one per input string, in input order;
/// confirm).
pub embeddings: Vec<Vec<f32>>,
/// Name of the model reported with the response.
pub model: String,
/// Token usage for the request.
pub usage: EmbeddingUsage,
}
/// Token usage reported with an [`EmbeddingResponse`].
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingUsage {
/// Number of prompt tokens.
pub prompt_tokens: usize,
/// Total number of tokens.
pub total_tokens: usize,
}
/// A request for a model completion.
///
/// The derived `Default` leaves every `Option` as `None`, every `Vec` empty
/// and `model` as an empty string.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CompletionRequest {
/// Name of the model to use.
pub model: String,
/// The conversation messages, in order.
pub messages: Vec<Message>,
/// Optional system prompt.
pub system: Option<SystemPrompt>,
/// Tools made available for this request.
pub tools: Vec<ToolDefinition>,
/// Optional token limit (presumably caps generated tokens; confirm).
pub max_tokens: Option<usize>,
/// Optional sampling temperature.
pub temperature: Option<f32>,
/// Optional `top_p` sampling parameter.
pub top_p: Option<f32>,
/// Stop sequences for the request.
pub stop_sequences: Vec<String>,
/// Optional tool-selection setting.
pub tool_choice: Option<ToolChoice>,
/// Optional output format.
pub response_format: Option<ResponseFormat>,
/// Optional thinking configuration.
pub thinking: Option<ThinkingConfig>,
/// Optional reasoning-effort level.
pub reasoning_effort: Option<ReasoningEffort>,
/// Optional additional JSON (presumably provider-specific passthrough;
/// confirm).
pub extra: Option<serde_json::Value>,
/// Optional context-management settings.
pub context_management: Option<ContextManagement>,
}
/// The result of a completion request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
/// Identifier of the response.
pub id: String,
/// Name of the model reported with the response.
pub model: String,
/// The message carried by the response.
pub message: Message,
/// Token usage for the request.
pub usage: TokenUsage,
/// Why generation stopped.
pub stop_reason: StopReason,
}
/// The reason a completion stopped, as reported in [`CompletionResponse`].
///
/// NOTE(review): variant meanings below are inferred from their names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum StopReason {
/// The model finished its turn.
EndTurn,
/// The model stopped in order to use a tool.
ToolUse,
/// A token limit was reached.
MaxTokens,
/// A stop sequence was hit.
StopSequence,
/// Output was stopped by a content filter.
ContentFilter,
/// Generation stopped for compaction (presumably tied to context
/// compaction; confirm).
Compaction,
}
/// Token usage reported with a [`CompletionResponse`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
/// Number of input tokens.
pub input_tokens: usize,
/// Number of output tokens.
pub output_tokens: usize,
/// Number of tokens read from cache, when reported.
pub cache_read_tokens: Option<usize>,
/// Number of tokens written to cache, when reported.
pub cache_creation_tokens: Option<usize>,
/// Number of reasoning tokens, when reported.
pub reasoning_tokens: Option<usize>,
/// Per-iteration usage breakdown, when reported.
pub iterations: Option<Vec<UsageIteration>>,
}
/// Describes a tool that can be offered in a [`CompletionRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Name of the tool.
pub name: String,
/// Optional title (presumably a human-readable display name; confirm).
pub title: Option<String>,
/// Description of the tool.
pub description: String,
/// Schema for the tool's input, as arbitrary JSON (presumably JSON Schema;
/// confirm).
pub input_schema: serde_json::Value,
/// Optional schema for the tool's output, as arbitrary JSON.
pub output_schema: Option<serde_json::Value>,
/// Optional behavioral hints about the tool.
pub annotations: Option<ToolAnnotations>,
/// Optional cache settings for this definition.
pub cache_control: Option<CacheControl>,
}
/// Optional hints about a tool's behavior, attached to a [`ToolDefinition`].
///
/// NOTE(review): hint meanings below are inferred from the field names; who
/// consumes them is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolAnnotations {
/// Hint that the tool is read-only.
pub read_only_hint: Option<bool>,
/// Hint that the tool may be destructive.
pub destructive_hint: Option<bool>,
/// Hint that the tool is idempotent.
pub idempotent_hint: Option<bool>,
/// Open-world hint (presumably whether the tool interacts with external
/// systems; confirm).
pub open_world_hint: Option<bool>,
}
/// The output of a tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolOutput {
/// Items produced by the tool.
pub content: Vec<ContentItem>,
/// Optional structured output, as arbitrary JSON.
pub structured_content: Option<serde_json::Value>,
/// Whether the output represents an error.
pub is_error: bool,
}
/// Contextual data bundled for a tool: working directory, session id,
/// environment map, cancellation token and an optional progress reporter.
///
/// No traits are derived on this type; a `Default` implementation is provided
/// separately in this file.
pub struct ToolContext {
/// Working directory.
pub cwd: PathBuf,
/// Identifier of the session.
pub session_id: String,
/// Environment entries as name/value pairs.
pub environment: HashMap<String, String>,
/// Token for cancellation signalling.
pub cancellation_token: CancellationToken,
/// Optional shared progress reporter.
pub progress_reporter: Option<Arc<dyn ProgressReporter>>,
}
impl Default for ToolContext {
    /// Builds a context rooted at the process's current directory, with an
    /// empty session id, an empty environment map, a fresh cancellation token
    /// and no progress reporter.
    ///
    /// If the current directory cannot be determined, `cwd` is `/tmp`.
    fn default() -> Self {
        // Resolve the working directory first so the fallback is explicit.
        let cwd = match std::env::current_dir() {
            Ok(dir) => dir,
            Err(_) => PathBuf::from("/tmp"),
        };

        Self {
            cwd,
            session_id: String::default(),
            environment: HashMap::default(),
            cancellation_token: CancellationToken::new(),
            progress_reporter: None,
        }
    }
}
/// Optional upper bounds on usage.
///
/// Every limit starts out as `None` (the derived `Default`) and can be set
/// with the chainable `with_*` methods. This type only stores the limits;
/// enforcement is not part of this block.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UsageLimits {
    /// Limit on the number of requests.
    pub request_limit: Option<usize>,
    /// Limit on the number of tool calls.
    pub tool_calls_limit: Option<usize>,
    /// Limit on input tokens.
    pub input_tokens_limit: Option<usize>,
    /// Limit on output tokens.
    pub output_tokens_limit: Option<usize>,
    /// Limit on total tokens.
    pub total_tokens_limit: Option<usize>,
}

impl UsageLimits {
    /// Returns `self` with `request_limit` replaced by `Some(limit)`.
    #[must_use]
    pub fn with_request_limit(self, limit: usize) -> Self {
        Self {
            request_limit: Some(limit),
            ..self
        }
    }

    /// Returns `self` with `tool_calls_limit` replaced by `Some(limit)`.
    #[must_use]
    pub fn with_tool_calls_limit(self, limit: usize) -> Self {
        Self {
            tool_calls_limit: Some(limit),
            ..self
        }
    }

    /// Returns `self` with `input_tokens_limit` replaced by `Some(limit)`.
    #[must_use]
    pub fn with_input_tokens_limit(self, limit: usize) -> Self {
        Self {
            input_tokens_limit: Some(limit),
            ..self
        }
    }

    /// Returns `self` with `output_tokens_limit` replaced by `Some(limit)`.
    #[must_use]
    pub fn with_output_tokens_limit(self, limit: usize) -> Self {
        Self {
            output_tokens_limit: Some(limit),
            ..self
        }
    }

    /// Returns `self` with `total_tokens_limit` replaced by `Some(limit)`.
    #[must_use]
    pub fn with_total_tokens_limit(self, limit: usize) -> Self {
        Self {
            total_tokens_limit: Some(limit),
            ..self
        }
    }
}
/// A sink for progress updates; held by [`ToolContext`] as an optional
/// `Arc<dyn ProgressReporter>`.
///
/// NOTE(review): the `WasmCompatSend`/`WasmCompatSync` bounds come from
/// `crate::wasm`; presumably they stand in for `Send`/`Sync` in a way that
/// also compiles on wasm targets. Confirm against that module.
pub trait ProgressReporter: WasmCompatSend + WasmCompatSync {
/// Reports a progress update.
///
/// `progress` is the current progress value, `total` is the expected total
/// when known, and `message` is an optional description. NOTE(review): the
/// units and scale of `progress`/`total` are not defined in this file.
fn report(&self, progress: f64, total: Option<f64>, message: Option<&str>);
}
impl From<String> for SystemPrompt {
fn from(s: String) -> Self {
SystemPrompt::Text(s)
}
}
impl From<&str> for SystemPrompt {
fn from(s: &str) -> Self {
SystemPrompt::Text(s.to_string())
}
}