context_mcp/
context.rs

//! Context data structures and core types.
//!
//! Inspired by memory-gate's LearningContext pattern, with enhancements
//! for temporal reasoning and MCP integration.

6use chrono::{DateTime, Duration, Utc};
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use uuid::Uuid;
10use base64::Engine;
11
12/// Unique identifier for a context entry
13#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
14pub struct ContextId(pub String);
15
16impl ContextId {
17    /// Generate a new random context ID
18    pub fn new() -> Self {
19        Self(Uuid::new_v4().to_string())
20    }
21
22    /// Generate a deterministic ID from content hash
23    pub fn from_content(content: &str) -> Self {
24        let mut hasher = Sha256::new();
25        hasher.update(content.as_bytes());
26        let hash = hasher.finalize();
27        Self(base64::engine::general_purpose::STANDARD.encode(&hash[..16]))
28    }
29
30    /// Create from a string
31    pub fn from_string(s: String) -> Self {
32        Self(s)
33    }
34
35    pub fn as_str(&self) -> &str {
36        &self.0
37    }
38}
39
40impl Default for ContextId {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl std::fmt::Display for ContextId {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        write!(f, "{}", self.0)
49    }
50}
51
52/// Domain classification for context entries
53#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
54#[serde(rename_all = "snake_case")]
55pub enum ContextDomain {
56    /// General purpose context
57    General,
58    /// Code and programming related
59    Code,
60    /// Documentation and technical writing
61    Documentation,
62    /// Conversation history
63    Conversation,
64    /// File system operations
65    Filesystem,
66    /// Web search results
67    WebSearch,
68    /// Dataset information
69    Dataset,
70    /// Research and papers
71    Research,
72    /// Custom domain with identifier
73    Custom(String),
74}
75
76impl Default for ContextDomain {
77    fn default() -> Self {
78        Self::General
79    }
80}
81
82/// Metadata associated with a context entry
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ContextMetadata {
85    /// Source of the context (e.g., "user", "web", "file")
86    #[serde(default)]
87    pub source: String,
88
89    /// Tags for categorization
90    #[serde(default)]
91    pub tags: Vec<String>,
92
93    /// Importance score (0.0 to 1.0)
94    #[serde(default = "default_importance")]
95    pub importance: f32,
96
97    /// Whether this context has been verified/screened
98    #[serde(default)]
99    pub verified: bool,
100
101    /// Security screening status
102    #[serde(default)]
103    pub screening_status: ScreeningStatus,
104
105    /// Custom key-value pairs
106    #[serde(default)]
107    pub custom: std::collections::HashMap<String, serde_json::Value>,
108}
109
110fn default_importance() -> f32 {
111    1.0
112}
113
114impl Default for ContextMetadata {
115    fn default() -> Self {
116        Self {
117            source: String::new(),
118            tags: Vec::new(),
119            importance: 1.0,
120            verified: false,
121            screening_status: ScreeningStatus::Unscreened,
122            custom: std::collections::HashMap::new(),
123        }
124    }
125}
126
127/// Security screening status for context entries
128#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
129#[serde(rename_all = "snake_case")]
130pub enum ScreeningStatus {
131    /// Not yet screened
132    #[default]
133    Unscreened,
134    /// Screened and safe
135    Safe,
136    /// Screened and flagged for review
137    Flagged,
138    /// Screened and blocked
139    Blocked,
140    /// Screening in progress
141    Pending,
142}
143
144/// A context entry for storage and retrieval
145///
146/// Inspired by memory-gate's LearningContext with additions for:
147/// - Temporal reasoning (created_at, accessed_at, expires_at)
148/// - Security screening integration
149/// - RAG-optimized fields
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct Context {
152    /// Unique identifier
153    pub id: ContextId,
154
155    /// Main content of the context
156    pub content: String,
157
158    /// Domain classification
159    pub domain: ContextDomain,
160
161    /// When this context was created
162    pub created_at: DateTime<Utc>,
163
164    /// When this context was last accessed
165    pub accessed_at: DateTime<Utc>,
166
167    /// Optional expiration time
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub expires_at: Option<DateTime<Utc>>,
170
171    /// Associated metadata
172    pub metadata: ContextMetadata,
173
174    /// Optional embedding vector for similarity search
175    #[serde(skip_serializing_if = "Option::is_none")]
176    pub embedding: Option<Vec<f32>>,
177}
178
179impl Context {
180    /// Create a new context entry
181    pub fn new(content: impl Into<String>, domain: ContextDomain) -> Self {
182        let content = content.into();
183        let now = Utc::now();
184        Self {
185            id: ContextId::from_content(&content),
186            content,
187            domain,
188            created_at: now,
189            accessed_at: now,
190            expires_at: None,
191            metadata: ContextMetadata::default(),
192            embedding: None,
193        }
194    }
195
196    /// Create with a specific ID
197    pub fn with_id(mut self, id: ContextId) -> Self {
198        self.id = id;
199        self
200    }
201
202    /// Set metadata
203    pub fn with_metadata(mut self, metadata: ContextMetadata) -> Self {
204        self.metadata = metadata;
205        self
206    }
207
208    /// Set source in metadata
209    pub fn with_source(mut self, source: impl Into<String>) -> Self {
210        self.metadata.source = source.into();
211        self
212    }
213
214    /// Set importance
215    pub fn with_importance(mut self, importance: f32) -> Self {
216        self.metadata.importance = importance.clamp(0.0, 1.0);
217        self
218    }
219
220    /// Add tags
221    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
222        self.metadata.tags = tags;
223        self
224    }
225
226    /// Set expiration
227    pub fn with_expiration(mut self, expires_at: DateTime<Utc>) -> Self {
228        self.expires_at = Some(expires_at);
229        self
230    }
231
232    /// Set embedding vector
233    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
234        self.embedding = Some(embedding);
235        self
236    }
237
238    /// Set TTL (time to live)
239    pub fn with_ttl(mut self, ttl: std::time::Duration) -> Self {
240        self.expires_at = Some(Utc::now() + Duration::from_std(ttl).unwrap_or(Duration::hours(24)));
241        self
242    }
243
244    /// Check if context has expired
245    pub fn is_expired(&self) -> bool {
246        self.expires_at
247            .map(|exp| Utc::now() > exp)
248            .unwrap_or(false)
249    }
250
251    /// Get age in seconds
252    pub fn age_seconds(&self) -> i64 {
253        (Utc::now() - self.created_at).num_seconds()
254    }
255
256    /// Get age in hours (useful for temporal reasoning)
257    pub fn age_hours(&self) -> f64 {
258        self.age_seconds() as f64 / 3600.0
259    }
260
261    /// Mark as accessed (updates accessed_at)
262    pub fn mark_accessed(&mut self) {
263        self.accessed_at = Utc::now();
264    }
265
266    /// Check if context is safe to use (screened)
267    pub fn is_safe(&self) -> bool {
268        matches!(
269            self.metadata.screening_status,
270            ScreeningStatus::Safe | ScreeningStatus::Unscreened
271        )
272    }
273}
274
275/// Builder for creating context queries
276#[derive(Debug, Clone, Default)]
277pub struct ContextQuery {
278    /// Text query for similarity search
279    pub query: Option<String>,
280    /// Filter by domain
281    pub domain_filter: Option<ContextDomain>,
282    /// Filter by tags (any match)
283    pub tag_filter: Option<Vec<String>>,
284    /// Filter by source
285    pub source_filter: Option<String>,
286    /// Minimum importance threshold
287    pub min_importance: Option<f32>,
288    /// Maximum age in seconds
289    pub max_age_seconds: Option<i64>,
290    /// Only return verified/screened context
291    pub verified_only: bool,
292    /// Maximum results to return
293    pub limit: usize,
294}
295
296impl ContextQuery {
297    pub fn new() -> Self {
298        Self {
299            limit: 10,
300            ..Default::default()
301        }
302    }
303
304    pub fn with_text(mut self, query: impl Into<String>) -> Self {
305        self.query = Some(query.into());
306        self
307    }
308
309    pub fn with_domain(mut self, domain: ContextDomain) -> Self {
310        self.domain_filter = Some(domain);
311        self
312    }
313
314    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
315        self.tag_filter = Some(tags);
316        self
317    }
318
319    pub fn with_min_importance(mut self, importance: f32) -> Self {
320        self.min_importance = Some(importance);
321        self
322    }
323
324    pub fn with_max_age(mut self, seconds: i64) -> Self {
325        self.max_age_seconds = Some(seconds);
326        self
327    }
328
329    pub fn with_max_age_hours(mut self, hours: i64) -> Self {
330        self.max_age_seconds = Some(hours * 3600);
331        self
332    }
333
334    pub fn with_tag(mut self, tag: String) -> Self {
335        if self.tag_filter.is_none() {
336            self.tag_filter = Some(Vec::new());
337        }
338        self.tag_filter.as_mut().unwrap().push(tag);
339        self
340    }
341
342    pub fn verified_only(mut self) -> Self {
343        self.verified_only = true;
344        self
345    }
346
347    pub fn with_limit(mut self, limit: usize) -> Self {
348        self.limit = limit;
349        self
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_creation() {
        let ctx = Context::new("Test content", ContextDomain::Code);
        assert!(!ctx.content.is_empty());
        assert_eq!(ctx.domain, ContextDomain::Code);
        assert!(!ctx.is_expired());
        assert_eq!(ctx.created_at, ctx.accessed_at);
    }

    #[test]
    fn test_context_id_from_content() {
        let id1 = ContextId::from_content("hello world");
        let id2 = ContextId::from_content("hello world");
        let id3 = ContextId::from_content("different content");

        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_context_age() {
        let ctx = Context::new("Test", ContextDomain::General);
        assert!(ctx.age_seconds() >= 0);
        assert!(ctx.age_hours() >= 0.0);
    }

    #[test]
    fn test_context_query_builder() {
        let query = ContextQuery::new()
            .with_text("search term")
            .with_domain(ContextDomain::Code)
            .with_min_importance(0.5)
            .with_limit(20);

        assert_eq!(query.query, Some("search term".to_string()));
        assert_eq!(query.domain_filter, Some(ContextDomain::Code));
        assert_eq!(query.min_importance, Some(0.5));
        assert_eq!(query.limit, 20);
    }

    #[test]
    fn test_with_ttl_sets_future_expiration() {
        let ctx = Context::new("ttl", ContextDomain::General)
            .with_ttl(std::time::Duration::from_secs(3600));
        let expires_at = ctx.expires_at.expect("a one-hour TTL must set expires_at");
        assert!(expires_at > ctx.created_at);
        assert!(!ctx.is_expired());
    }

    #[test]
    fn test_past_expiration_is_expired() {
        let ctx = Context::new("old", ContextDomain::General)
            .with_expiration(Utc::now() - Duration::hours(1));
        assert!(ctx.is_expired());
    }

    #[test]
    fn test_with_importance_clamps() {
        let high = Context::new("a", ContextDomain::General).with_importance(2.0);
        let low = Context::new("b", ContextDomain::General).with_importance(-1.0);
        assert_eq!(high.metadata.importance, 1.0);
        assert_eq!(low.metadata.importance, 0.0);
    }

    #[test]
    fn test_is_safe_depends_on_screening_status() {
        let mut ctx = Context::new("s", ContextDomain::General);
        // Fresh entries are unscreened, which `is_safe` accepts.
        assert!(ctx.is_safe());
        ctx.metadata.screening_status = ScreeningStatus::Blocked;
        assert!(!ctx.is_safe());
        ctx.metadata.screening_status = ScreeningStatus::Safe;
        assert!(ctx.is_safe());
    }

    #[test]
    fn test_query_with_tag_accumulates() {
        let query = ContextQuery::new()
            .with_tag("a".to_string())
            .with_tag("b".to_string());
        assert_eq!(query.tag_filter, Some(vec!["a".to_string(), "b".to_string()]));
        assert_eq!(ContextQuery::new().limit, 10);
    }

    #[test]
    fn test_query_max_age_hours() {
        let query = ContextQuery::new().with_max_age_hours(2);
        assert_eq!(query.max_age_seconds, Some(7200));
    }

    #[test]
    fn test_domain_default_is_general() {
        assert_eq!(ContextDomain::default(), ContextDomain::General);
    }
}