context_mcp/
context.rs

1//! Context data structures and core types
2//!
3//! Inspired by memory-gate's LearningContext pattern with enhancements
4//! for temporal reasoning and MCP integration.
5
6use base64::Engine;
7use chrono::{DateTime, Duration, Utc};
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use uuid::Uuid;
11
12/// Unique identifier for a context entry
13#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
14pub struct ContextId(pub String);
15
16impl ContextId {
17    /// Generate a new random context ID
18    pub fn new() -> Self {
19        Self(Uuid::new_v4().to_string())
20    }
21
22    /// Generate a deterministic ID from content hash
23    pub fn from_content(content: &str) -> Self {
24        let mut hasher = Sha256::new();
25        hasher.update(content.as_bytes());
26        let hash = hasher.finalize();
27        Self(base64::engine::general_purpose::STANDARD.encode(&hash[..16]))
28    }
29
30    /// Create from a string
31    pub fn from_string(s: String) -> Self {
32        Self(s)
33    }
34
35    pub fn as_str(&self) -> &str {
36        &self.0
37    }
38}
39
40impl Default for ContextId {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl std::fmt::Display for ContextId {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        write!(f, "{}", self.0)
49    }
50}
51
52/// Domain classification for context entries
53#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
54#[serde(rename_all = "snake_case")]
55pub enum ContextDomain {
56    /// General purpose context
57    General,
58    /// Code and programming related
59    Code,
60    /// Documentation and technical writing
61    Documentation,
62    /// Conversation history
63    Conversation,
64    /// File system operations
65    Filesystem,
66    /// Web search results
67    WebSearch,
68    /// Dataset information
69    Dataset,
70    /// Research and papers
71    Research,
72    /// Custom domain with identifier
73    Custom(String),
74}
75
76#[allow(clippy::derivable_impls)]
77impl Default for ContextDomain {
78    fn default() -> Self {
79        Self::General
80    }
81}
82
83/// Metadata associated with a context entry
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct ContextMetadata {
86    /// Source of the context (e.g., "user", "web", "file")
87    #[serde(default)]
88    pub source: String,
89
90    /// Tags for categorization
91    #[serde(default)]
92    pub tags: Vec<String>,
93
94    /// Importance score (0.0 to 1.0)
95    #[serde(default = "default_importance")]
96    pub importance: f32,
97
98    /// Whether this context has been verified/screened
99    #[serde(default)]
100    pub verified: bool,
101
102    /// Security screening status
103    #[serde(default)]
104    pub screening_status: ScreeningStatus,
105
106    /// Custom key-value pairs
107    #[serde(default)]
108    pub custom: std::collections::HashMap<String, serde_json::Value>,
109}
110
/// Serde default for `ContextMetadata::importance`: an entry is treated as
/// fully important unless stated otherwise.
fn default_importance() -> f32 {
    1.0_f32
}
114
115impl Default for ContextMetadata {
116    fn default() -> Self {
117        Self {
118            source: String::new(),
119            tags: Vec::new(),
120            importance: 1.0,
121            verified: false,
122            screening_status: ScreeningStatus::Unscreened,
123            custom: std::collections::HashMap::new(),
124        }
125    }
126}
127
128/// Security screening status for context entries
129#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
130#[serde(rename_all = "snake_case")]
131pub enum ScreeningStatus {
132    /// Not yet screened
133    #[default]
134    Unscreened,
135    /// Screened and safe
136    Safe,
137    /// Screened and flagged for review
138    Flagged,
139    /// Screened and blocked
140    Blocked,
141    /// Screening in progress
142    Pending,
143}
144
145/// A context entry for storage and retrieval
146///
147/// Inspired by memory-gate's LearningContext with additions for:
148/// - Temporal reasoning (created_at, accessed_at, expires_at)
149/// - Security screening integration
150/// - RAG-optimized fields
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct Context {
153    /// Unique identifier
154    pub id: ContextId,
155
156    /// Main content of the context
157    pub content: String,
158
159    /// Domain classification
160    pub domain: ContextDomain,
161
162    /// When this context was created
163    pub created_at: DateTime<Utc>,
164
165    /// When this context was last accessed
166    pub accessed_at: DateTime<Utc>,
167
168    /// Optional expiration time
169    #[serde(skip_serializing_if = "Option::is_none")]
170    pub expires_at: Option<DateTime<Utc>>,
171
172    /// Associated metadata
173    pub metadata: ContextMetadata,
174
175    /// Optional embedding vector for similarity search
176    #[serde(skip_serializing_if = "Option::is_none")]
177    pub embedding: Option<Vec<f32>>,
178}
179
180impl Context {
181    /// Create a new context entry
182    pub fn new(content: impl Into<String>, domain: ContextDomain) -> Self {
183        let content = content.into();
184        let now = Utc::now();
185        Self {
186            id: ContextId::from_content(&content),
187            content,
188            domain,
189            created_at: now,
190            accessed_at: now,
191            expires_at: None,
192            metadata: ContextMetadata::default(),
193            embedding: None,
194        }
195    }
196
197    /// Create with a specific ID
198    pub fn with_id(mut self, id: ContextId) -> Self {
199        self.id = id;
200        self
201    }
202
203    /// Set metadata
204    pub fn with_metadata(mut self, metadata: ContextMetadata) -> Self {
205        self.metadata = metadata;
206        self
207    }
208
209    /// Set source in metadata
210    pub fn with_source(mut self, source: impl Into<String>) -> Self {
211        self.metadata.source = source.into();
212        self
213    }
214
215    /// Set importance
216    pub fn with_importance(mut self, importance: f32) -> Self {
217        self.metadata.importance = importance.clamp(0.0, 1.0);
218        self
219    }
220
221    /// Add tags
222    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
223        self.metadata.tags = tags;
224        self
225    }
226
227    /// Set expiration
228    pub fn with_expiration(mut self, expires_at: DateTime<Utc>) -> Self {
229        self.expires_at = Some(expires_at);
230        self
231    }
232
233    /// Set embedding vector
234    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
235        self.embedding = Some(embedding);
236        self
237    }
238
239    /// Set TTL (time to live)
240    pub fn with_ttl(mut self, ttl: std::time::Duration) -> Self {
241        self.expires_at = Some(Utc::now() + Duration::from_std(ttl).unwrap_or(Duration::hours(24)));
242        self
243    }
244
245    /// Check if context has expired
246    pub fn is_expired(&self) -> bool {
247        self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
248    }
249
250    /// Get age in seconds
251    pub fn age_seconds(&self) -> i64 {
252        (Utc::now() - self.created_at).num_seconds()
253    }
254
255    /// Get age in hours (useful for temporal reasoning)
256    pub fn age_hours(&self) -> f64 {
257        self.age_seconds() as f64 / 3600.0
258    }
259
260    /// Mark as accessed (updates accessed_at)
261    pub fn mark_accessed(&mut self) {
262        self.accessed_at = Utc::now();
263    }
264
265    /// Check if context is safe to use (screened)
266    pub fn is_safe(&self) -> bool {
267        matches!(
268            self.metadata.screening_status,
269            ScreeningStatus::Safe | ScreeningStatus::Unscreened
270        )
271    }
272}
273
274/// Builder for creating context queries
275#[derive(Debug, Clone, Default)]
276pub struct ContextQuery {
277    /// Text query for similarity search
278    pub query: Option<String>,
279    /// Filter by domain
280    pub domain_filter: Option<ContextDomain>,
281    /// Filter by tags (any match)
282    pub tag_filter: Option<Vec<String>>,
283    /// Filter by source
284    pub source_filter: Option<String>,
285    /// Minimum importance threshold
286    pub min_importance: Option<f32>,
287    /// Maximum age in seconds
288    pub max_age_seconds: Option<i64>,
289    /// Only return verified/screened context
290    pub verified_only: bool,
291    /// Maximum results to return
292    pub limit: usize,
293}
294
295impl ContextQuery {
296    pub fn new() -> Self {
297        Self {
298            limit: 10,
299            ..Default::default()
300        }
301    }
302
303    pub fn with_text(mut self, query: impl Into<String>) -> Self {
304        self.query = Some(query.into());
305        self
306    }
307
308    pub fn with_domain(mut self, domain: ContextDomain) -> Self {
309        self.domain_filter = Some(domain);
310        self
311    }
312
313    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
314        self.tag_filter = Some(tags);
315        self
316    }
317
318    pub fn with_min_importance(mut self, importance: f32) -> Self {
319        self.min_importance = Some(importance);
320        self
321    }
322
323    pub fn with_max_age(mut self, seconds: i64) -> Self {
324        self.max_age_seconds = Some(seconds);
325        self
326    }
327
328    pub fn with_max_age_hours(mut self, hours: i64) -> Self {
329        self.max_age_seconds = Some(hours * 3600);
330        self
331    }
332
333    pub fn with_tag(mut self, tag: String) -> Self {
334        if self.tag_filter.is_none() {
335            self.tag_filter = Some(Vec::new());
336        }
337        self.tag_filter.as_mut().unwrap().push(tag);
338        self
339    }
340
341    pub fn verified_only(mut self) -> Self {
342        self.verified_only = true;
343        self
344    }
345
346    pub fn with_limit(mut self, limit: usize) -> Self {
347        self.limit = limit;
348        self
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_creation() {
        let ctx = Context::new("Test content", ContextDomain::Code);
        assert_eq!(ctx.content, "Test content");
        assert_eq!(ctx.domain, ContextDomain::Code);
        assert_eq!(ctx.id, ContextId::from_content("Test content"));
        assert_eq!(ctx.created_at, ctx.accessed_at);
        assert!(ctx.expires_at.is_none());
        assert!(ctx.embedding.is_none());
        assert!(!ctx.is_expired());
    }

    #[test]
    fn test_context_id_from_content() {
        let id1 = ContextId::from_content("hello world");
        let id2 = ContextId::from_content("hello world");
        let id3 = ContextId::from_content("different content");

        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_context_id_string_forms() {
        let id = ContextId::from_string("abc-123".to_string());
        assert_eq!(id.to_string(), "abc-123");
        assert_eq!(id.as_str(), "abc-123");
        assert_ne!(ContextId::new(), ContextId::new());
    }

    #[test]
    fn test_context_age() {
        let ctx = Context::new("Test", ContextDomain::General);
        assert!(ctx.age_seconds() >= 0);
        assert!(ctx.age_hours() >= 0.0);
    }

    #[test]
    fn test_importance_is_clamped() {
        let high = Context::new("a", ContextDomain::General).with_importance(1.5);
        let low = Context::new("b", ContextDomain::General).with_importance(-0.25);
        assert_eq!(high.metadata.importance, 1.0);
        assert_eq!(low.metadata.importance, 0.0);
    }

    #[test]
    fn test_expiration() {
        let past = Context::new("old", ContextDomain::General)
            .with_expiration(Utc::now() - Duration::seconds(1));
        assert!(past.is_expired());

        let future = Context::new("new", ContextDomain::General)
            .with_ttl(std::time::Duration::from_secs(3600));
        assert!(future.expires_at.is_some());
        assert!(!future.is_expired());
    }

    #[test]
    fn test_screening_status_safety() {
        let mut ctx = Context::new("x", ContextDomain::General);
        // Unscreened entries are deliberately treated as safe.
        assert!(ctx.is_safe());
        ctx.metadata.screening_status = ScreeningStatus::Safe;
        assert!(ctx.is_safe());
        for status in [
            ScreeningStatus::Flagged,
            ScreeningStatus::Blocked,
            ScreeningStatus::Pending,
        ] {
            ctx.metadata.screening_status = status;
            assert!(!ctx.is_safe());
        }
    }

    #[test]
    fn test_defaults() {
        assert_eq!(ContextDomain::default(), ContextDomain::General);
        let meta = ContextMetadata::default();
        assert_eq!(meta.importance, 1.0);
        assert!(!meta.verified);
        assert_eq!(meta.screening_status, ScreeningStatus::Unscreened);
        assert!(meta.source.is_empty() && meta.tags.is_empty() && meta.custom.is_empty());
    }

    #[test]
    fn test_serde_names_and_defaults() {
        assert_eq!(
            serde_json::to_string(&ContextDomain::WebSearch).unwrap(),
            "\"web_search\""
        );
        let meta: ContextMetadata = serde_json::from_str("{}").unwrap();
        assert_eq!(meta.importance, 1.0);
        assert_eq!(meta.screening_status, ScreeningStatus::Unscreened);
    }

    #[test]
    fn test_context_query_builder() {
        let query = ContextQuery::new()
            .with_text("search term")
            .with_domain(ContextDomain::Code)
            .with_min_importance(0.5)
            .with_limit(20);

        assert_eq!(query.query, Some("search term".to_string()));
        assert_eq!(query.domain_filter, Some(ContextDomain::Code));
        assert_eq!(query.min_importance, Some(0.5));
        assert_eq!(query.limit, 20);
    }

    #[test]
    fn test_context_query_tags_and_age() {
        let query = ContextQuery::new()
            .with_tag("a".to_string())
            .with_tag("b".to_string())
            .with_max_age_hours(2)
            .verified_only();

        assert_eq!(query.tag_filter, Some(vec!["a".to_string(), "b".to_string()]));
        assert_eq!(query.max_age_seconds, Some(7200));
        assert!(query.verified_only);
        assert_eq!(ContextQuery::new().limit, 10);
    }
}