1use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use uuid::Uuid;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct NormalizedMessage {
20 pub id: Uuid,
22 pub role: MessageRole,
24 pub content: String,
26 pub source_provider: String,
28 pub source_model: Option<String>,
30 pub attachments: Vec<Attachment>,
32 pub tool_calls: Vec<ToolCall>,
34 pub token_count: Option<usize>,
36 pub timestamp: DateTime<Utc>,
38 pub metadata: HashMap<String, serde_json::Value>,
40}
41
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
44#[serde(rename_all = "lowercase")]
45pub enum MessageRole {
46 User,
47 Assistant,
48 System,
49 Tool,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct Attachment {
55 pub id: Uuid,
57 pub attachment_type: AttachmentType,
59 pub name: Option<String>,
61 pub mime_type: String,
63 pub content: String,
65 pub url: Option<String>,
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
71#[serde(rename_all = "snake_case")]
72pub enum AttachmentType {
73 Image,
74 File,
75 Code,
76 Audio,
77 Video,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct ToolCall {
83 pub id: String,
85 pub name: String,
87 pub arguments: serde_json::Value,
89 pub result: Option<String>,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct ConversationContext {
100 pub id: Uuid,
102 pub title: String,
104 pub system_prompt: Option<String>,
106 pub messages: Vec<NormalizedMessage>,
108 pub summary: Option<ConversationSummary>,
110 pub tools: Vec<ToolDefinition>,
112 pub provider_history: Vec<ProviderSwitch>,
114 pub created_at: DateTime<Utc>,
116 pub updated_at: DateTime<Utc>,
118 pub metadata: HashMap<String, serde_json::Value>,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct ConversationSummary {
125 pub text: String,
127 pub topics: Vec<String>,
129 pub entities: Vec<String>,
131 pub goals: Vec<String>,
133 pub up_to_message_id: Uuid,
135 pub generated_at: DateTime<Utc>,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct ToolDefinition {
142 pub name: String,
144 pub description: String,
146 pub parameters: serde_json::Value,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct ProviderSwitch {
153 pub from_provider: String,
155 pub from_model: Option<String>,
157 pub to_provider: String,
159 pub to_model: Option<String>,
161 pub reason: Option<String>,
163 pub switched_at: DateTime<Utc>,
165}
166
167pub trait ProviderAdapter: Send + Sync {
173 fn provider_name(&self) -> &str;
175
176 fn to_provider_format(&self, context: &ConversationContext) -> ProviderMessages;
178
179 fn from_provider_format(&self, response: &ProviderResponse) -> NormalizedMessage;
181
182 fn capabilities(&self) -> ProviderCapabilities;
184
185 fn estimate_tokens(&self, messages: &[NormalizedMessage]) -> usize;
187}
188
189#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct ProviderMessages {
192 pub messages: Vec<serde_json::Value>,
194 pub system: Option<String>,
196 pub tools: Option<Vec<serde_json::Value>>,
198}
199
200#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct ProviderResponse {
203 pub provider: String,
205 pub model: String,
207 pub content: String,
209 pub tool_calls: Vec<ToolCall>,
211 pub usage: Option<UsageStats>,
213 pub raw: serde_json::Value,
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct UsageStats {
220 pub prompt_tokens: usize,
221 pub completion_tokens: usize,
222 pub total_tokens: usize,
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct ProviderCapabilities {
228 pub vision: bool,
230 pub tools: bool,
232 pub system_messages: bool,
234 pub max_context: usize,
236 pub streaming: bool,
238}
239
/// Stateless adapter for the OpenAI Chat Completions format.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenAIAdapter;
246
247impl ProviderAdapter for OpenAIAdapter {
248 fn provider_name(&self) -> &str {
249 "openai"
250 }
251
252 fn to_provider_format(&self, context: &ConversationContext) -> ProviderMessages {
253 let mut messages: Vec<serde_json::Value> = vec![];
254
255 if let Some(ref system) = context.system_prompt {
257 messages.push(serde_json::json!({
258 "role": "system",
259 "content": system
260 }));
261 }
262
263 for msg in &context.messages {
265 let role = match msg.role {
266 MessageRole::User => "user",
267 MessageRole::Assistant => "assistant",
268 MessageRole::System => "system",
269 MessageRole::Tool => "tool",
270 };
271
272 let mut message = serde_json::json!({
273 "role": role,
274 "content": msg.content
275 });
276
277 if !msg.tool_calls.is_empty() && msg.role == MessageRole::Assistant {
279 message["tool_calls"] = serde_json::json!(msg.tool_calls.iter().map(|tc| {
280 serde_json::json!({
281 "id": tc.id,
282 "type": "function",
283 "function": {
284 "name": tc.name,
285 "arguments": tc.arguments.to_string()
286 }
287 })
288 }).collect::<Vec<_>>());
289 }
290
291 if !msg.attachments.is_empty() {
293 let content_parts: Vec<serde_json::Value> = std::iter::once(
294 serde_json::json!({ "type": "text", "text": msg.content })
295 ).chain(msg.attachments.iter().filter(|a| a.attachment_type == AttachmentType::Image).map(|a| {
296 if let Some(ref url) = a.url {
297 serde_json::json!({
298 "type": "image_url",
299 "image_url": { "url": url }
300 })
301 } else {
302 serde_json::json!({
303 "type": "image_url",
304 "image_url": { "url": format!("data:{};base64,{}", a.mime_type, a.content) }
305 })
306 }
307 })).collect();
308
309 message["content"] = serde_json::json!(content_parts);
310 }
311
312 messages.push(message);
313 }
314
315 let tools = if !context.tools.is_empty() {
317 Some(context.tools.iter().map(|t| {
318 serde_json::json!({
319 "type": "function",
320 "function": {
321 "name": t.name,
322 "description": t.description,
323 "parameters": t.parameters
324 }
325 })
326 }).collect())
327 } else {
328 None
329 };
330
331 ProviderMessages {
332 messages,
333 system: None, tools,
335 }
336 }
337
338 fn from_provider_format(&self, response: &ProviderResponse) -> NormalizedMessage {
339 NormalizedMessage {
340 id: Uuid::new_v4(),
341 role: MessageRole::Assistant,
342 content: response.content.clone(),
343 source_provider: "openai".to_string(),
344 source_model: Some(response.model.clone()),
345 attachments: vec![],
346 tool_calls: response.tool_calls.clone(),
347 token_count: response.usage.as_ref().map(|u| u.completion_tokens),
348 timestamp: Utc::now(),
349 metadata: HashMap::new(),
350 }
351 }
352
353 fn capabilities(&self) -> ProviderCapabilities {
354 ProviderCapabilities {
355 vision: true,
356 tools: true,
357 system_messages: true,
358 max_context: 128000,
359 streaming: true,
360 }
361 }
362
363 fn estimate_tokens(&self, messages: &[NormalizedMessage]) -> usize {
364 messages.iter().map(|m| m.content.len() / 4).sum()
366 }
367}
368
/// Stateless adapter for the Anthropic Messages format.
#[derive(Debug, Clone, Copy, Default)]
pub struct AnthropicAdapter;
375
376impl ProviderAdapter for AnthropicAdapter {
377 fn provider_name(&self) -> &str {
378 "anthropic"
379 }
380
381 fn to_provider_format(&self, context: &ConversationContext) -> ProviderMessages {
382 let mut messages: Vec<serde_json::Value> = vec![];
383
384 for msg in &context.messages {
385 let role = match msg.role {
387 MessageRole::User | MessageRole::Tool => "user",
388 MessageRole::Assistant => "assistant",
389 MessageRole::System => continue, };
391
392 let mut content_parts: Vec<serde_json::Value> = vec![];
393
394 content_parts.push(serde_json::json!({
396 "type": "text",
397 "text": msg.content
398 }));
399
400 for attachment in &msg.attachments {
402 if attachment.attachment_type == AttachmentType::Image {
403 content_parts.push(serde_json::json!({
404 "type": "image",
405 "source": {
406 "type": "base64",
407 "media_type": attachment.mime_type,
408 "data": attachment.content
409 }
410 }));
411 }
412 }
413
414 messages.push(serde_json::json!({
415 "role": role,
416 "content": content_parts
417 }));
418 }
419
420 let tools = if !context.tools.is_empty() {
422 Some(context.tools.iter().map(|t| {
423 serde_json::json!({
424 "name": t.name,
425 "description": t.description,
426 "input_schema": t.parameters
427 })
428 }).collect())
429 } else {
430 None
431 };
432
433 ProviderMessages {
434 messages,
435 system: context.system_prompt.clone(),
436 tools,
437 }
438 }
439
440 fn from_provider_format(&self, response: &ProviderResponse) -> NormalizedMessage {
441 NormalizedMessage {
442 id: Uuid::new_v4(),
443 role: MessageRole::Assistant,
444 content: response.content.clone(),
445 source_provider: "anthropic".to_string(),
446 source_model: Some(response.model.clone()),
447 attachments: vec![],
448 tool_calls: response.tool_calls.clone(),
449 token_count: response.usage.as_ref().map(|u| u.completion_tokens),
450 timestamp: Utc::now(),
451 metadata: HashMap::new(),
452 }
453 }
454
455 fn capabilities(&self) -> ProviderCapabilities {
456 ProviderCapabilities {
457 vision: true,
458 tools: true,
459 system_messages: true,
460 max_context: 200000,
461 streaming: true,
462 }
463 }
464
465 fn estimate_tokens(&self, messages: &[NormalizedMessage]) -> usize {
466 messages.iter().map(|m| m.content.len() / 4).sum()
467 }
468}
469
470pub struct ContinuationManager {
476 adapters: HashMap<String, Box<dyn ProviderAdapter>>,
478 contexts: HashMap<Uuid, ConversationContext>,
480}
481
482impl ContinuationManager {
483 pub fn new() -> Self {
485 let mut adapters: HashMap<String, Box<dyn ProviderAdapter>> = HashMap::new();
486 adapters.insert("openai".to_string(), Box::new(OpenAIAdapter));
487 adapters.insert("anthropic".to_string(), Box::new(AnthropicAdapter));
488
489 Self {
490 adapters,
491 contexts: HashMap::new(),
492 }
493 }
494
495 pub fn register_adapter(&mut self, adapter: Box<dyn ProviderAdapter>) {
497 self.adapters.insert(adapter.provider_name().to_string(), adapter);
498 }
499
500 pub fn create_context(&mut self, title: &str, system_prompt: Option<&str>) -> Uuid {
502 let id = Uuid::new_v4();
503 let context = ConversationContext {
504 id,
505 title: title.to_string(),
506 system_prompt: system_prompt.map(String::from),
507 messages: vec![],
508 summary: None,
509 tools: vec![],
510 provider_history: vec![],
511 created_at: Utc::now(),
512 updated_at: Utc::now(),
513 metadata: HashMap::new(),
514 };
515 self.contexts.insert(id, context);
516 id
517 }
518
519 pub fn add_message(&mut self, context_id: Uuid, message: NormalizedMessage) -> bool {
521 if let Some(context) = self.contexts.get_mut(&context_id) {
522 context.messages.push(message);
523 context.updated_at = Utc::now();
524 true
525 } else {
526 false
527 }
528 }
529
530 pub fn switch_provider(
532 &mut self,
533 context_id: Uuid,
534 to_provider: &str,
535 to_model: Option<&str>,
536 reason: Option<&str>,
537 ) -> Option<ProviderMessages> {
538 let context = self.contexts.get_mut(&context_id)?;
539 let adapter = self.adapters.get(to_provider)?;
540
541 let last_provider = context.provider_history.last();
543 let switch = ProviderSwitch {
544 from_provider: last_provider.map(|p| p.to_provider.clone()).unwrap_or_default(),
545 from_model: last_provider.and_then(|p| p.to_model.clone()),
546 to_provider: to_provider.to_string(),
547 to_model: to_model.map(String::from),
548 reason: reason.map(String::from),
549 switched_at: Utc::now(),
550 };
551 context.provider_history.push(switch);
552 context.updated_at = Utc::now();
553
554 Some(adapter.to_provider_format(context))
556 }
557
558 pub fn get_provider_messages(&self, context_id: Uuid, provider: &str) -> Option<ProviderMessages> {
560 let context = self.contexts.get(&context_id)?;
561 let adapter = self.adapters.get(provider)?;
562 Some(adapter.to_provider_format(context))
563 }
564
565 pub fn process_response(&mut self, context_id: Uuid, response: &ProviderResponse) -> Option<NormalizedMessage> {
567 let adapter = self.adapters.get(&response.provider)?;
568 let message = adapter.from_provider_format(response);
569
570 if let Some(context) = self.contexts.get_mut(&context_id) {
571 context.messages.push(message.clone());
572 context.updated_at = Utc::now();
573 }
574
575 Some(message)
576 }
577
578 pub fn get_context(&self, context_id: Uuid) -> Option<&ConversationContext> {
580 self.contexts.get(&context_id)
581 }
582
583 pub fn estimate_tokens(&self, context_id: Uuid, provider: &str) -> Option<usize> {
585 let context = self.contexts.get(&context_id)?;
586 let adapter = self.adapters.get(provider)?;
587 Some(adapter.estimate_tokens(&context.messages))
588 }
589
590 pub fn compress_context(&mut self, context_id: Uuid, summary_text: &str, topics: Vec<String>) -> bool {
592 if let Some(context) = self.contexts.get_mut(&context_id) {
593 let last_message_id = context.messages.last().map(|m| m.id).unwrap_or(Uuid::nil());
594 context.summary = Some(ConversationSummary {
595 text: summary_text.to_string(),
596 topics,
597 entities: vec![],
598 goals: vec![],
599 up_to_message_id: last_message_id,
600 generated_at: Utc::now(),
601 });
602 context.updated_at = Utc::now();
603 true
604 } else {
605 false
606 }
607 }
608}
609
610impl Default for ContinuationManager {
611 fn default() -> Self {
612 Self::new()
613 }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain user message attributed to OpenAI.
    fn user_message(content: &str, model: Option<&str>) -> NormalizedMessage {
        NormalizedMessage {
            id: Uuid::new_v4(),
            role: MessageRole::User,
            content: content.to_string(),
            source_provider: "openai".to_string(),
            source_model: model.map(str::to_string),
            attachments: Vec::new(),
            tool_calls: Vec::new(),
            token_count: None,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_create_context() {
        let mut manager = ContinuationManager::new();
        let id = manager.create_context("Test Conversation", Some("You are helpful."));

        let context = manager.get_context(id).expect("context was just created");
        assert_eq!(context.title, "Test Conversation");
        assert_eq!(context.system_prompt.as_deref(), Some("You are helpful."));
    }

    #[test]
    fn test_add_message() {
        let mut manager = ContinuationManager::new();
        let id = manager.create_context("Test", None);

        assert!(manager.add_message(id, user_message("Hello!", Some("gpt-4"))));

        let stored = &manager.get_context(id).expect("context exists").messages;
        assert_eq!(stored.len(), 1);
    }

    #[test]
    fn test_provider_switch() {
        let mut manager = ContinuationManager::new();
        let id = manager.create_context("Test", Some("System prompt"));
        manager.add_message(id, user_message("Hello!", None));

        let rendered = manager.switch_provider(
            id,
            "anthropic",
            Some("claude-sonnet-4-20250514"),
            Some("Better for writing"),
        );
        assert!(rendered.is_some());

        let history = &manager.get_context(id).expect("context exists").provider_history;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].to_provider, "anthropic");
    }
}