// chasm_cli/agency/models.rs

1// Copyright (c) 2024-2026 Nervosys LLC
2// SPDX-License-Identifier: Apache-2.0
3//! Agency Data Models
4//!
5//! Core data structures for the Agent Development Kit.
6
7#![allow(dead_code)]
8
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// A single message in an Agency conversation transcript.
///
/// Serializes compactly: empty collections and `None` optionals are omitted
/// on output and filled with defaults on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgencyMessage {
    /// Unique message ID.
    pub id: String,
    /// Who produced the message: user, assistant, system, or tool.
    pub role: MessageRole,
    /// Message text content.
    pub content: String,
    /// Tool calls requested in this message, if any.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Result of a tool execution (present when `role` is `Tool`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_result: Option<ToolResult>,
    /// When the message was created (UTC).
    pub timestamp: DateTime<Utc>,
    /// Token count for this message, if the provider reported one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokens: Option<u32>,
    /// Name of the agent associated with this message, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_name: Option<String>,
    /// Free-form additional metadata attached to the message.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, serde_json::Value>,
}
40
/// Role of a conversation message's author.
///
/// Serialized as the lowercase variant name (e.g. `"user"`), which matches
/// the manual `Display` implementation below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Human user input.
    User,
    /// Model/agent output.
    Assistant,
    /// System prompt or instruction.
    System,
    /// Tool execution output.
    Tool,
}
50
51impl std::fmt::Display for MessageRole {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            MessageRole::User => write!(f, "user"),
55            MessageRole::Assistant => write!(f, "assistant"),
56            MessageRole::System => write!(f, "system"),
57            MessageRole::Tool => write!(f, "tool"),
58        }
59    }
60}
61
/// Tool call request emitted by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique call ID; echoed back as [`ToolResult::call_id`].
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments as an arbitrary JSON value.
    pub arguments: serde_json::Value,
    /// When the call was issued; defaults to "now" if absent from input.
    #[serde(default = "Utc::now")]
    pub timestamp: DateTime<Utc>,
}
75
/// Result from executing a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// ID of the [`ToolCall`] this result answers.
    pub call_id: String,
    /// Name of the tool that was executed.
    pub name: String,
    /// Whether execution succeeded.
    pub success: bool,
    /// Result content on success, or an error message on failure.
    pub content: String,
    /// Execution duration in milliseconds (0 when not provided).
    #[serde(default)]
    pub duration_ms: u64,
    /// Optional structured output data, in addition to `content`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
94
/// Event emitted during agent execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgencyEvent {
    /// What happened (see [`EventType`]).
    pub event_type: EventType,
    /// Name of the agent this event concerns.
    pub agent_name: String,
    /// Event-type-specific payload as arbitrary JSON.
    pub data: serde_json::Value,
    /// When the event occurred (UTC).
    // NOTE(review): unlike `ToolCall::timestamp`, this field has no serde
    // default, so deserialization fails if it is missing — confirm intended.
    pub timestamp: DateTime<Utc>,
    /// Session this event belongs to, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
}
110
/// Types of events during execution - matches csm-shared AgencyEventType
///
/// Serialized in `snake_case` (e.g. `AgentStarted` -> `"agent_started"`),
/// which matches the manual `Display` implementation below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EventType {
    /// Agent started processing
    AgentStarted,
    /// Agent is thinking
    AgentThinking,
    /// Agent is executing
    AgentExecuting,
    /// Agent completed processing
    AgentCompleted,
    /// Agent failed
    AgentFailed,
    /// Tool call started
    ToolCallStarted,
    /// Tool call completed
    ToolCallCompleted,
    /// Tool call failed
    ToolCallFailed,
    /// Message created
    MessageCreated,
    /// Message delta (streaming)
    MessageDelta,
    /// Task created
    TaskCreated,
    /// Task started
    TaskStarted,
    /// Task completed
    TaskCompleted,
    /// Task failed
    TaskFailed,
    /// Swarm started
    SwarmStarted,
    /// Agent joined swarm
    SwarmAgentJoined,
    /// Swarm completed
    SwarmCompleted,
    /// Swarm failed
    SwarmFailed,
    /// Agent handoff to another agent
    Handoff,
    /// Error occurred
    Error,
}
156
157impl std::fmt::Display for EventType {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        match self {
160            EventType::AgentStarted => write!(f, "agent_started"),
161            EventType::AgentThinking => write!(f, "agent_thinking"),
162            EventType::AgentExecuting => write!(f, "agent_executing"),
163            EventType::AgentCompleted => write!(f, "agent_completed"),
164            EventType::AgentFailed => write!(f, "agent_failed"),
165            EventType::ToolCallStarted => write!(f, "tool_call_started"),
166            EventType::ToolCallCompleted => write!(f, "tool_call_completed"),
167            EventType::ToolCallFailed => write!(f, "tool_call_failed"),
168            EventType::MessageCreated => write!(f, "message_created"),
169            EventType::MessageDelta => write!(f, "message_delta"),
170            EventType::TaskCreated => write!(f, "task_created"),
171            EventType::TaskStarted => write!(f, "task_started"),
172            EventType::TaskCompleted => write!(f, "task_completed"),
173            EventType::TaskFailed => write!(f, "task_failed"),
174            EventType::SwarmStarted => write!(f, "swarm_started"),
175            EventType::SwarmAgentJoined => write!(f, "swarm_agent_joined"),
176            EventType::SwarmCompleted => write!(f, "swarm_completed"),
177            EventType::SwarmFailed => write!(f, "swarm_failed"),
178            EventType::Handoff => write!(f, "handoff"),
179            EventType::Error => write!(f, "error"),
180        }
181    }
182}
183
/// Token usage statistics for one or more model calls.
///
/// `Default` yields an all-zero counter suitable for accumulating with
/// [`TokenUsage::add`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens produced as completion/output.
    pub completion_tokens: u32,
    /// Total of prompt and completion tokens.
    pub total_tokens: u32,
}
194
195impl TokenUsage {
196    pub fn new(prompt: u32, completion: u32) -> Self {
197        Self {
198            prompt_tokens: prompt,
199            completion_tokens: completion,
200            total_tokens: prompt + completion,
201        }
202    }
203
204    pub fn add(&mut self, other: &TokenUsage) {
205        self.prompt_tokens += other.prompt_tokens;
206        self.completion_tokens += other.completion_tokens;
207        self.total_tokens += other.total_tokens;
208    }
209}
210
/// Model configuration for an agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model identifier (e.g., "gemini-2.5-flash", "gpt-4o").
    pub model: String,
    /// Provider type; defaults to [`ModelProvider::Google`].
    #[serde(default)]
    pub provider: ModelProvider,
    /// API endpoint override, if not using the provider's well-known one
    /// (presumably [`ModelProvider::default_endpoint`] — confirm with callers).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub endpoint: Option<String>,
    /// API key; `None` means it is resolved elsewhere (e.g. an environment
    /// variable, per the original comment).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Sampling temperature (0.0 - 2.0); defaults to 0.7 via
    /// `default_temperature`.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Maximum output tokens; provider default when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Top-p (nucleus) sampling cutoff; provider default when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
235
/// Serde default for [`ModelConfig::temperature`].
fn default_temperature() -> f32 {
    0.7
}
239
240impl Default for ModelConfig {
241    fn default() -> Self {
242        Self {
243            model: "gemini-2.5-flash".to_string(),
244            provider: ModelProvider::Google,
245            endpoint: None,
246            api_key: None,
247            temperature: 0.7,
248            max_tokens: None,
249            top_p: None,
250        }
251    }
252}
253
/// Supported model providers.
///
/// Serialized as the lowercase variant name via `rename_all = "lowercase"`.
// NOTE(review): lowercase serde renaming yields "openaicompatible" for
// `OpenAICompatible`, while the `Display` impl prints "openai_compatible" —
// confirm this divergence between wire format and display name is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelProvider {
    // Cloud Providers
    /// Google (Gemini); the default provider.
    #[default]
    Google,
    OpenAI,
    Anthropic,
    Azure,
    Groq,
    Together,
    Fireworks,
    DeepSeek,
    Mistral,
    Cohere,
    Perplexity,

    // Local Providers (served from localhost by default)
    Ollama,
    LMStudio,
    Jan,
    GPT4All,
    LocalAI,
    Llamafile,
    TextGenWebUI,
    VLLM,
    KoboldCpp,
    TabbyML,
    Exo,

    // Generic
    /// Any server speaking the OpenAI wire format; requires a custom endpoint.
    OpenAICompatible,
    /// Fully custom provider; requires a custom endpoint.
    Custom,
}
289
290impl ModelProvider {
291    /// Get the default endpoint for this provider
292    pub fn default_endpoint(&self) -> Option<&'static str> {
293        match self {
294            // Cloud Providers
295            ModelProvider::Google => Some("https://generativelanguage.googleapis.com/v1"),
296            ModelProvider::OpenAI => Some("https://api.openai.com/v1"),
297            ModelProvider::Anthropic => Some("https://api.anthropic.com/v1"),
298            ModelProvider::Azure => None, // Requires custom endpoint
299            ModelProvider::Groq => Some("https://api.groq.com/openai/v1"),
300            ModelProvider::Together => Some("https://api.together.xyz/v1"),
301            ModelProvider::Fireworks => Some("https://api.fireworks.ai/inference/v1"),
302            ModelProvider::DeepSeek => Some("https://api.deepseek.com/v1"),
303            ModelProvider::Mistral => Some("https://api.mistral.ai/v1"),
304            ModelProvider::Cohere => Some("https://api.cohere.ai/v1"),
305            ModelProvider::Perplexity => Some("https://api.perplexity.ai"),
306
307            // Local Providers
308            ModelProvider::Ollama => Some("http://localhost:11434"),
309            ModelProvider::LMStudio => Some("http://localhost:1234/v1"),
310            ModelProvider::Jan => Some("http://localhost:1337/v1"),
311            ModelProvider::GPT4All => Some("http://localhost:4891/v1"),
312            ModelProvider::LocalAI => Some("http://localhost:8080/v1"),
313            ModelProvider::Llamafile => Some("http://localhost:8080/v1"),
314            ModelProvider::TextGenWebUI => Some("http://localhost:5000/v1"),
315            ModelProvider::VLLM => Some("http://localhost:8000/v1"),
316            ModelProvider::KoboldCpp => Some("http://localhost:5001/v1"),
317            ModelProvider::TabbyML => Some("http://localhost:8080/v1"),
318            ModelProvider::Exo => Some("http://localhost:52415/v1"),
319
320            // Generic
321            ModelProvider::OpenAICompatible => None, // Requires custom endpoint
322            ModelProvider::Custom => None,
323        }
324    }
325
326    /// Check if this provider is a local provider
327    pub fn is_local(&self) -> bool {
328        matches!(
329            self,
330            ModelProvider::Ollama
331                | ModelProvider::LMStudio
332                | ModelProvider::Jan
333                | ModelProvider::GPT4All
334                | ModelProvider::LocalAI
335                | ModelProvider::Llamafile
336                | ModelProvider::TextGenWebUI
337                | ModelProvider::VLLM
338                | ModelProvider::KoboldCpp
339                | ModelProvider::TabbyML
340                | ModelProvider::Exo
341        )
342    }
343
344    /// Check if this provider uses OpenAI-compatible API
345    pub fn is_openai_compatible(&self) -> bool {
346        matches!(
347            self,
348            ModelProvider::OpenAI
349                | ModelProvider::Azure
350                | ModelProvider::Groq
351                | ModelProvider::Together
352                | ModelProvider::Fireworks
353                | ModelProvider::DeepSeek
354                | ModelProvider::Mistral
355                | ModelProvider::Perplexity
356                | ModelProvider::LMStudio
357                | ModelProvider::Jan
358                | ModelProvider::GPT4All
359                | ModelProvider::LocalAI
360                | ModelProvider::Llamafile
361                | ModelProvider::TextGenWebUI
362                | ModelProvider::VLLM
363                | ModelProvider::KoboldCpp
364                | ModelProvider::TabbyML
365                | ModelProvider::Exo
366                | ModelProvider::OpenAICompatible
367        )
368    }
369}
370
371impl std::fmt::Display for ModelProvider {
372    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373        match self {
374            // Cloud Providers
375            ModelProvider::Google => write!(f, "google"),
376            ModelProvider::OpenAI => write!(f, "openai"),
377            ModelProvider::Anthropic => write!(f, "anthropic"),
378            ModelProvider::Azure => write!(f, "azure"),
379            ModelProvider::Groq => write!(f, "groq"),
380            ModelProvider::Together => write!(f, "together"),
381            ModelProvider::Fireworks => write!(f, "fireworks"),
382            ModelProvider::DeepSeek => write!(f, "deepseek"),
383            ModelProvider::Mistral => write!(f, "mistral"),
384            ModelProvider::Cohere => write!(f, "cohere"),
385            ModelProvider::Perplexity => write!(f, "perplexity"),
386            // Local Providers
387            ModelProvider::Ollama => write!(f, "ollama"),
388            ModelProvider::LMStudio => write!(f, "lmstudio"),
389            ModelProvider::Jan => write!(f, "jan"),
390            ModelProvider::GPT4All => write!(f, "gpt4all"),
391            ModelProvider::LocalAI => write!(f, "localai"),
392            ModelProvider::Llamafile => write!(f, "llamafile"),
393            ModelProvider::TextGenWebUI => write!(f, "textgenwebui"),
394            ModelProvider::VLLM => write!(f, "vllm"),
395            ModelProvider::KoboldCpp => write!(f, "koboldcpp"),
396            ModelProvider::TabbyML => write!(f, "tabbyml"),
397            ModelProvider::Exo => write!(f, "exo"),
398            // Generic
399            ModelProvider::OpenAICompatible => write!(f, "openai_compatible"),
400            ModelProvider::Custom => write!(f, "custom"),
401        }
402    }
403}