// chasm/routing/continuation.rs
// Copyright (c) 2024-2027 Nervosys LLC
// SPDX-License-Identifier: AGPL-3.0-only
//! Conversation continuation across providers.
//!
//! Enables seamless continuation of conversations when switching between
//! different AI providers or models.

8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use uuid::Uuid;
12
13// ============================================================================
14// Conversation State
15// ============================================================================
16
17/// Normalized message format for cross-provider continuity
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct NormalizedMessage {
20    /// Unique message ID
21    pub id: Uuid,
22    /// Role (user, assistant, system, tool)
23    pub role: MessageRole,
24    /// Message content
25    pub content: String,
26    /// Original provider
27    pub source_provider: String,
28    /// Original model
29    pub source_model: Option<String>,
30    /// Attachments (images, files)
31    pub attachments: Vec<Attachment>,
32    /// Tool calls made
33    pub tool_calls: Vec<ToolCall>,
34    /// Token count
35    pub token_count: Option<usize>,
36    /// Timestamp
37    pub timestamp: DateTime<Utc>,
38    /// Metadata
39    pub metadata: HashMap<String, serde_json::Value>,
40}
41
42/// Message role
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
44#[serde(rename_all = "lowercase")]
45pub enum MessageRole {
46    User,
47    Assistant,
48    System,
49    Tool,
50}
51
52/// Attachment (image, file, etc.)
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct Attachment {
55    /// Attachment ID
56    pub id: Uuid,
57    /// Attachment type
58    pub attachment_type: AttachmentType,
59    /// File name
60    pub name: Option<String>,
61    /// MIME type
62    pub mime_type: String,
63    /// Content (base64 encoded for binary)
64    pub content: String,
65    /// URL if hosted
66    pub url: Option<String>,
67}
68
69/// Attachment type
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
71#[serde(rename_all = "snake_case")]
72pub enum AttachmentType {
73    Image,
74    File,
75    Code,
76    Audio,
77    Video,
78}
79
80/// Tool call
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct ToolCall {
83    /// Tool call ID
84    pub id: String,
85    /// Tool name
86    pub name: String,
87    /// Arguments as JSON
88    pub arguments: serde_json::Value,
89    /// Result if available
90    pub result: Option<String>,
91}
92
93// ============================================================================
94// Conversation Context
95// ============================================================================
96
97/// Portable conversation context
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct ConversationContext {
100    /// Conversation ID
101    pub id: Uuid,
102    /// Title
103    pub title: String,
104    /// System prompt
105    pub system_prompt: Option<String>,
106    /// Normalized messages
107    pub messages: Vec<NormalizedMessage>,
108    /// Summary for context compression
109    pub summary: Option<ConversationSummary>,
110    /// Available tools
111    pub tools: Vec<ToolDefinition>,
112    /// Provider history (list of providers used)
113    pub provider_history: Vec<ProviderSwitch>,
114    /// Creation timestamp
115    pub created_at: DateTime<Utc>,
116    /// Last updated
117    pub updated_at: DateTime<Utc>,
118    /// Metadata
119    pub metadata: HashMap<String, serde_json::Value>,
120}
121
122/// Conversation summary for context compression
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct ConversationSummary {
125    /// Summary text
126    pub text: String,
127    /// Key topics discussed
128    pub topics: Vec<String>,
129    /// Important entities mentioned
130    pub entities: Vec<String>,
131    /// User's apparent goals
132    pub goals: Vec<String>,
133    /// Messages summarized up to
134    pub up_to_message_id: Uuid,
135    /// Generated at
136    pub generated_at: DateTime<Utc>,
137}
138
139/// Tool definition for portable tools
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct ToolDefinition {
142    /// Tool name
143    pub name: String,
144    /// Description
145    pub description: String,
146    /// Parameters schema (JSON Schema)
147    pub parameters: serde_json::Value,
148}
149
150/// Record of provider switch
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct ProviderSwitch {
153    /// From provider
154    pub from_provider: String,
155    /// From model
156    pub from_model: Option<String>,
157    /// To provider
158    pub to_provider: String,
159    /// To model
160    pub to_model: Option<String>,
161    /// Reason for switch
162    pub reason: Option<String>,
163    /// Timestamp
164    pub switched_at: DateTime<Utc>,
165}
166
167// ============================================================================
168// Provider Adapters
169// ============================================================================
170
171/// Provider-specific message format adapter
172pub trait ProviderAdapter: Send + Sync {
173    /// Provider name
174    fn provider_name(&self) -> &str;
175
176    /// Convert normalized messages to provider format
177    fn to_provider_format(&self, context: &ConversationContext) -> ProviderMessages;
178
179    /// Convert provider response to normalized format
180    fn from_provider_format(&self, response: &ProviderResponse) -> NormalizedMessage;
181
182    /// Get supported features
183    fn capabilities(&self) -> ProviderCapabilities;
184
185    /// Estimate tokens for messages
186    fn estimate_tokens(&self, messages: &[NormalizedMessage]) -> usize;
187}
188
189/// Provider-specific messages
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct ProviderMessages {
192    /// Messages in provider format
193    pub messages: Vec<serde_json::Value>,
194    /// System message (if separate)
195    pub system: Option<String>,
196    /// Tools in provider format
197    pub tools: Option<Vec<serde_json::Value>>,
198}
199
200/// Provider response
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct ProviderResponse {
203    /// Provider name
204    pub provider: String,
205    /// Model used
206    pub model: String,
207    /// Response content
208    pub content: String,
209    /// Tool calls
210    pub tool_calls: Vec<ToolCall>,
211    /// Usage statistics
212    pub usage: Option<UsageStats>,
213    /// Raw response
214    pub raw: serde_json::Value,
215}
216
217/// Usage statistics
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct UsageStats {
220    pub prompt_tokens: usize,
221    pub completion_tokens: usize,
222    pub total_tokens: usize,
223}
224
225/// Provider capabilities
226#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct ProviderCapabilities {
228    /// Supports vision/images
229    pub vision: bool,
230    /// Supports tool/function calling
231    pub tools: bool,
232    /// Supports system messages
233    pub system_messages: bool,
234    /// Maximum context length
235    pub max_context: usize,
236    /// Supports streaming
237    pub streaming: bool,
238}
239
// ============================================================================
// OpenAI Adapter
// ============================================================================

/// Stateless adapter for the OpenAI chat message format.
pub struct OpenAIAdapter;

247impl ProviderAdapter for OpenAIAdapter {
248    fn provider_name(&self) -> &str {
249        "openai"
250    }
251
252    fn to_provider_format(&self, context: &ConversationContext) -> ProviderMessages {
253        let mut messages: Vec<serde_json::Value> = vec![];
254
255        // Add system message
256        if let Some(ref system) = context.system_prompt {
257            messages.push(serde_json::json!({
258                "role": "system",
259                "content": system
260            }));
261        }
262
263        // Convert messages
264        for msg in &context.messages {
265            let role = match msg.role {
266                MessageRole::User => "user",
267                MessageRole::Assistant => "assistant",
268                MessageRole::System => "system",
269                MessageRole::Tool => "tool",
270            };
271
272            let mut message = serde_json::json!({
273                "role": role,
274                "content": msg.content
275            });
276
277            // Add tool calls for assistant messages
278            if !msg.tool_calls.is_empty() && msg.role == MessageRole::Assistant {
279                message["tool_calls"] = serde_json::json!(msg.tool_calls.iter().map(|tc| {
280                    serde_json::json!({
281                        "id": tc.id,
282                        "type": "function",
283                        "function": {
284                            "name": tc.name,
285                            "arguments": tc.arguments.to_string()
286                        }
287                    })
288                }).collect::<Vec<_>>());
289            }
290
291            // Add images for vision
292            if !msg.attachments.is_empty() {
293                let content_parts: Vec<serde_json::Value> = std::iter::once(
294                    serde_json::json!({ "type": "text", "text": msg.content })
295                ).chain(msg.attachments.iter().filter(|a| a.attachment_type == AttachmentType::Image).map(|a| {
296                    if let Some(ref url) = a.url {
297                        serde_json::json!({
298                            "type": "image_url",
299                            "image_url": { "url": url }
300                        })
301                    } else {
302                        serde_json::json!({
303                            "type": "image_url",
304                            "image_url": { "url": format!("data:{};base64,{}", a.mime_type, a.content) }
305                        })
306                    }
307                })).collect();
308
309                message["content"] = serde_json::json!(content_parts);
310            }
311
312            messages.push(message);
313        }
314
315        // Convert tools
316        let tools = if !context.tools.is_empty() {
317            Some(context.tools.iter().map(|t| {
318                serde_json::json!({
319                    "type": "function",
320                    "function": {
321                        "name": t.name,
322                        "description": t.description,
323                        "parameters": t.parameters
324                    }
325                })
326            }).collect())
327        } else {
328            None
329        };
330
331        ProviderMessages {
332            messages,
333            system: None, // Included in messages
334            tools,
335        }
336    }
337
338    fn from_provider_format(&self, response: &ProviderResponse) -> NormalizedMessage {
339        NormalizedMessage {
340            id: Uuid::new_v4(),
341            role: MessageRole::Assistant,
342            content: response.content.clone(),
343            source_provider: "openai".to_string(),
344            source_model: Some(response.model.clone()),
345            attachments: vec![],
346            tool_calls: response.tool_calls.clone(),
347            token_count: response.usage.as_ref().map(|u| u.completion_tokens),
348            timestamp: Utc::now(),
349            metadata: HashMap::new(),
350        }
351    }
352
353    fn capabilities(&self) -> ProviderCapabilities {
354        ProviderCapabilities {
355            vision: true,
356            tools: true,
357            system_messages: true,
358            max_context: 128000,
359            streaming: true,
360        }
361    }
362
363    fn estimate_tokens(&self, messages: &[NormalizedMessage]) -> usize {
364        // Rough estimation: ~4 chars per token
365        messages.iter().map(|m| m.content.len() / 4).sum()
366    }
367}
368
// ============================================================================
// Anthropic Adapter
// ============================================================================

/// Stateless adapter for the Anthropic Claude message format.
pub struct AnthropicAdapter;

376impl ProviderAdapter for AnthropicAdapter {
377    fn provider_name(&self) -> &str {
378        "anthropic"
379    }
380
381    fn to_provider_format(&self, context: &ConversationContext) -> ProviderMessages {
382        let mut messages: Vec<serde_json::Value> = vec![];
383
384        for msg in &context.messages {
385            // Anthropic uses "user" and "assistant" roles only
386            let role = match msg.role {
387                MessageRole::User | MessageRole::Tool => "user",
388                MessageRole::Assistant => "assistant",
389                MessageRole::System => continue, // System handled separately
390            };
391
392            let mut content_parts: Vec<serde_json::Value> = vec![];
393
394            // Add text content
395            content_parts.push(serde_json::json!({
396                "type": "text",
397                "text": msg.content
398            }));
399
400            // Add images
401            for attachment in &msg.attachments {
402                if attachment.attachment_type == AttachmentType::Image {
403                    content_parts.push(serde_json::json!({
404                        "type": "image",
405                        "source": {
406                            "type": "base64",
407                            "media_type": attachment.mime_type,
408                            "data": attachment.content
409                        }
410                    }));
411                }
412            }
413
414            messages.push(serde_json::json!({
415                "role": role,
416                "content": content_parts
417            }));
418        }
419
420        // Convert tools to Anthropic format
421        let tools = if !context.tools.is_empty() {
422            Some(context.tools.iter().map(|t| {
423                serde_json::json!({
424                    "name": t.name,
425                    "description": t.description,
426                    "input_schema": t.parameters
427                })
428            }).collect())
429        } else {
430            None
431        };
432
433        ProviderMessages {
434            messages,
435            system: context.system_prompt.clone(),
436            tools,
437        }
438    }
439
440    fn from_provider_format(&self, response: &ProviderResponse) -> NormalizedMessage {
441        NormalizedMessage {
442            id: Uuid::new_v4(),
443            role: MessageRole::Assistant,
444            content: response.content.clone(),
445            source_provider: "anthropic".to_string(),
446            source_model: Some(response.model.clone()),
447            attachments: vec![],
448            tool_calls: response.tool_calls.clone(),
449            token_count: response.usage.as_ref().map(|u| u.completion_tokens),
450            timestamp: Utc::now(),
451            metadata: HashMap::new(),
452        }
453    }
454
455    fn capabilities(&self) -> ProviderCapabilities {
456        ProviderCapabilities {
457            vision: true,
458            tools: true,
459            system_messages: true,
460            max_context: 200000,
461            streaming: true,
462        }
463    }
464
465    fn estimate_tokens(&self, messages: &[NormalizedMessage]) -> usize {
466        messages.iter().map(|m| m.content.len() / 4).sum()
467    }
468}
469
470// ============================================================================
471// Continuation Manager
472// ============================================================================
473
474/// Manages conversation continuation across providers
475pub struct ContinuationManager {
476    /// Provider adapters
477    adapters: HashMap<String, Box<dyn ProviderAdapter>>,
478    /// Active contexts
479    contexts: HashMap<Uuid, ConversationContext>,
480}
481
482impl ContinuationManager {
483    /// Create a new continuation manager
484    pub fn new() -> Self {
485        let mut adapters: HashMap<String, Box<dyn ProviderAdapter>> = HashMap::new();
486        adapters.insert("openai".to_string(), Box::new(OpenAIAdapter));
487        adapters.insert("anthropic".to_string(), Box::new(AnthropicAdapter));
488
489        Self {
490            adapters,
491            contexts: HashMap::new(),
492        }
493    }
494
495    /// Register a provider adapter
496    pub fn register_adapter(&mut self, adapter: Box<dyn ProviderAdapter>) {
497        self.adapters.insert(adapter.provider_name().to_string(), adapter);
498    }
499
500    /// Create a new conversation context
501    pub fn create_context(&mut self, title: &str, system_prompt: Option<&str>) -> Uuid {
502        let id = Uuid::new_v4();
503        let context = ConversationContext {
504            id,
505            title: title.to_string(),
506            system_prompt: system_prompt.map(String::from),
507            messages: vec![],
508            summary: None,
509            tools: vec![],
510            provider_history: vec![],
511            created_at: Utc::now(),
512            updated_at: Utc::now(),
513            metadata: HashMap::new(),
514        };
515        self.contexts.insert(id, context);
516        id
517    }
518
519    /// Add a message to the context
520    pub fn add_message(&mut self, context_id: Uuid, message: NormalizedMessage) -> bool {
521        if let Some(context) = self.contexts.get_mut(&context_id) {
522            context.messages.push(message);
523            context.updated_at = Utc::now();
524            true
525        } else {
526            false
527        }
528    }
529
530    /// Switch provider for a conversation
531    pub fn switch_provider(
532        &mut self,
533        context_id: Uuid,
534        to_provider: &str,
535        to_model: Option<&str>,
536        reason: Option<&str>,
537    ) -> Option<ProviderMessages> {
538        let context = self.contexts.get_mut(&context_id)?;
539        let adapter = self.adapters.get(to_provider)?;
540
541        // Record the switch
542        let last_provider = context.provider_history.last();
543        let switch = ProviderSwitch {
544            from_provider: last_provider.map(|p| p.to_provider.clone()).unwrap_or_default(),
545            from_model: last_provider.and_then(|p| p.to_model.clone()),
546            to_provider: to_provider.to_string(),
547            to_model: to_model.map(String::from),
548            reason: reason.map(String::from),
549            switched_at: Utc::now(),
550        };
551        context.provider_history.push(switch);
552        context.updated_at = Utc::now();
553
554        // Convert to new provider format
555        Some(adapter.to_provider_format(context))
556    }
557
558    /// Get context for a provider
559    pub fn get_provider_messages(&self, context_id: Uuid, provider: &str) -> Option<ProviderMessages> {
560        let context = self.contexts.get(&context_id)?;
561        let adapter = self.adapters.get(provider)?;
562        Some(adapter.to_provider_format(context))
563    }
564
565    /// Process a response from a provider
566    pub fn process_response(&mut self, context_id: Uuid, response: &ProviderResponse) -> Option<NormalizedMessage> {
567        let adapter = self.adapters.get(&response.provider)?;
568        let message = adapter.from_provider_format(response);
569        
570        if let Some(context) = self.contexts.get_mut(&context_id) {
571            context.messages.push(message.clone());
572            context.updated_at = Utc::now();
573        }
574
575        Some(message)
576    }
577
578    /// Get a context
579    pub fn get_context(&self, context_id: Uuid) -> Option<&ConversationContext> {
580        self.contexts.get(&context_id)
581    }
582
583    /// Estimate tokens for a context on a provider
584    pub fn estimate_tokens(&self, context_id: Uuid, provider: &str) -> Option<usize> {
585        let context = self.contexts.get(&context_id)?;
586        let adapter = self.adapters.get(provider)?;
587        Some(adapter.estimate_tokens(&context.messages))
588    }
589
590    /// Compress context by generating a summary
591    pub fn compress_context(&mut self, context_id: Uuid, summary_text: &str, topics: Vec<String>) -> bool {
592        if let Some(context) = self.contexts.get_mut(&context_id) {
593            let last_message_id = context.messages.last().map(|m| m.id).unwrap_or(Uuid::nil());
594            context.summary = Some(ConversationSummary {
595                text: summary_text.to_string(),
596                topics,
597                entities: vec![],
598                goals: vec![],
599                up_to_message_id: last_message_id,
600                generated_at: Utc::now(),
601            });
602            context.updated_at = Utc::now();
603            true
604        } else {
605            false
606        }
607    }
608}
609
610impl Default for ContinuationManager {
611    fn default() -> Self {
612        Self::new()
613    }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a user message from `"openai"` with the given text and model.
    fn user_message(content: &str, model: Option<&str>) -> NormalizedMessage {
        NormalizedMessage {
            id: Uuid::new_v4(),
            role: MessageRole::User,
            content: content.to_string(),
            source_provider: "openai".to_string(),
            source_model: model.map(str::to_string),
            attachments: Vec::new(),
            tool_calls: Vec::new(),
            token_count: None,
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_create_context() {
        let mut manager = ContinuationManager::new();
        let id = manager.create_context("Test Conversation", Some("You are helpful."));

        let context = manager.get_context(id).expect("context was just created");
        assert_eq!(context.title, "Test Conversation");
        assert_eq!(context.system_prompt.as_deref(), Some("You are helpful."));
    }

    #[test]
    fn test_add_message() {
        let mut manager = ContinuationManager::new();
        let id = manager.create_context("Test", None);

        assert!(manager.add_message(id, user_message("Hello!", Some("gpt-4"))));

        let stored = &manager.get_context(id).expect("context exists").messages;
        assert_eq!(stored.len(), 1);
    }

    #[test]
    fn test_provider_switch() {
        let mut manager = ContinuationManager::new();
        let id = manager.create_context("Test", Some("System prompt"));
        manager.add_message(id, user_message("Hello!", None));

        let payload = manager.switch_provider(
            id,
            "anthropic",
            Some("claude-sonnet-4-20250514"),
            Some("Better for writing"),
        );
        assert!(payload.is_some());

        let history = &manager.get_context(id).expect("context exists").provider_history;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].to_provider, "anthropic");
    }
}