1use chrono::{DateTime, Duration, Utc};
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use uuid::Uuid;
10use base64::Engine;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
14pub struct ContextId(pub String);
15
16impl ContextId {
17 pub fn new() -> Self {
19 Self(Uuid::new_v4().to_string())
20 }
21
22 pub fn from_content(content: &str) -> Self {
24 let mut hasher = Sha256::new();
25 hasher.update(content.as_bytes());
26 let hash = hasher.finalize();
27 Self(base64::engine::general_purpose::STANDARD.encode(&hash[..16]))
28 }
29
30 pub fn from_string(s: String) -> Self {
32 Self(s)
33 }
34
35 pub fn as_str(&self) -> &str {
36 &self.0
37 }
38}
39
40impl Default for ContextId {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl std::fmt::Display for ContextId {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 write!(f, "{}", self.0)
49 }
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
54#[serde(rename_all = "snake_case")]
55pub enum ContextDomain {
56 General,
58 Code,
60 Documentation,
62 Conversation,
64 Filesystem,
66 WebSearch,
68 Dataset,
70 Research,
72 Custom(String),
74}
75
76impl Default for ContextDomain {
77 fn default() -> Self {
78 Self::General
79 }
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ContextMetadata {
85 #[serde(default)]
87 pub source: String,
88
89 #[serde(default)]
91 pub tags: Vec<String>,
92
93 #[serde(default = "default_importance")]
95 pub importance: f32,
96
97 #[serde(default)]
99 pub verified: bool,
100
101 #[serde(default)]
103 pub screening_status: ScreeningStatus,
104
105 #[serde(default)]
107 pub custom: std::collections::HashMap<String, serde_json::Value>,
108}
109
/// Serde default for `ContextMetadata::importance`: full importance, `1.0`.
fn default_importance() -> f32 {
    1.0_f32
}
113
114impl Default for ContextMetadata {
115 fn default() -> Self {
116 Self {
117 source: String::new(),
118 tags: Vec::new(),
119 importance: 1.0,
120 verified: false,
121 screening_status: ScreeningStatus::Unscreened,
122 custom: std::collections::HashMap::new(),
123 }
124 }
125}
126
127#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
129#[serde(rename_all = "snake_case")]
130pub enum ScreeningStatus {
131 #[default]
133 Unscreened,
134 Safe,
136 Flagged,
138 Blocked,
140 Pending,
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct Context {
152 pub id: ContextId,
154
155 pub content: String,
157
158 pub domain: ContextDomain,
160
161 pub created_at: DateTime<Utc>,
163
164 pub accessed_at: DateTime<Utc>,
166
167 #[serde(skip_serializing_if = "Option::is_none")]
169 pub expires_at: Option<DateTime<Utc>>,
170
171 pub metadata: ContextMetadata,
173
174 #[serde(skip_serializing_if = "Option::is_none")]
176 pub embedding: Option<Vec<f32>>,
177}
178
179impl Context {
180 pub fn new(content: impl Into<String>, domain: ContextDomain) -> Self {
182 let content = content.into();
183 let now = Utc::now();
184 Self {
185 id: ContextId::from_content(&content),
186 content,
187 domain,
188 created_at: now,
189 accessed_at: now,
190 expires_at: None,
191 metadata: ContextMetadata::default(),
192 embedding: None,
193 }
194 }
195
196 pub fn with_id(mut self, id: ContextId) -> Self {
198 self.id = id;
199 self
200 }
201
202 pub fn with_metadata(mut self, metadata: ContextMetadata) -> Self {
204 self.metadata = metadata;
205 self
206 }
207
208 pub fn with_source(mut self, source: impl Into<String>) -> Self {
210 self.metadata.source = source.into();
211 self
212 }
213
214 pub fn with_importance(mut self, importance: f32) -> Self {
216 self.metadata.importance = importance.clamp(0.0, 1.0);
217 self
218 }
219
220 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
222 self.metadata.tags = tags;
223 self
224 }
225
226 pub fn with_expiration(mut self, expires_at: DateTime<Utc>) -> Self {
228 self.expires_at = Some(expires_at);
229 self
230 }
231
232 pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
234 self.embedding = Some(embedding);
235 self
236 }
237
238 pub fn with_ttl(mut self, ttl: std::time::Duration) -> Self {
240 self.expires_at = Some(Utc::now() + Duration::from_std(ttl).unwrap_or(Duration::hours(24)));
241 self
242 }
243
244 pub fn is_expired(&self) -> bool {
246 self.expires_at
247 .map(|exp| Utc::now() > exp)
248 .unwrap_or(false)
249 }
250
251 pub fn age_seconds(&self) -> i64 {
253 (Utc::now() - self.created_at).num_seconds()
254 }
255
256 pub fn age_hours(&self) -> f64 {
258 self.age_seconds() as f64 / 3600.0
259 }
260
261 pub fn mark_accessed(&mut self) {
263 self.accessed_at = Utc::now();
264 }
265
266 pub fn is_safe(&self) -> bool {
268 matches!(
269 self.metadata.screening_status,
270 ScreeningStatus::Safe | ScreeningStatus::Unscreened
271 )
272 }
273}
274
275#[derive(Debug, Clone, Default)]
277pub struct ContextQuery {
278 pub query: Option<String>,
280 pub domain_filter: Option<ContextDomain>,
282 pub tag_filter: Option<Vec<String>>,
284 pub source_filter: Option<String>,
286 pub min_importance: Option<f32>,
288 pub max_age_seconds: Option<i64>,
290 pub verified_only: bool,
292 pub limit: usize,
294}
295
296impl ContextQuery {
297 pub fn new() -> Self {
298 Self {
299 limit: 10,
300 ..Default::default()
301 }
302 }
303
304 pub fn with_text(mut self, query: impl Into<String>) -> Self {
305 self.query = Some(query.into());
306 self
307 }
308
309 pub fn with_domain(mut self, domain: ContextDomain) -> Self {
310 self.domain_filter = Some(domain);
311 self
312 }
313
314 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
315 self.tag_filter = Some(tags);
316 self
317 }
318
319 pub fn with_min_importance(mut self, importance: f32) -> Self {
320 self.min_importance = Some(importance);
321 self
322 }
323
324 pub fn with_max_age(mut self, seconds: i64) -> Self {
325 self.max_age_seconds = Some(seconds);
326 self
327 }
328
329 pub fn with_max_age_hours(mut self, hours: i64) -> Self {
330 self.max_age_seconds = Some(hours * 3600);
331 self
332 }
333
334 pub fn with_tag(mut self, tag: String) -> Self {
335 if self.tag_filter.is_none() {
336 self.tag_filter = Some(Vec::new());
337 }
338 self.tag_filter.as_mut().unwrap().push(tag);
339 self
340 }
341
342 pub fn verified_only(mut self) -> Self {
343 self.verified_only = true;
344 self
345 }
346
347 pub fn with_limit(mut self, limit: usize) -> Self {
348 self.limit = limit;
349 self
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_creation() {
        let ctx = Context::new("Test content", ContextDomain::Code);
        assert_eq!(ctx.content, "Test content");
        assert_eq!(ctx.domain, ContextDomain::Code);
        assert_eq!(ctx.created_at, ctx.accessed_at);
        assert!(ctx.expires_at.is_none());
        assert!(ctx.embedding.is_none());
        assert!(!ctx.is_expired());
    }

    #[test]
    fn test_context_id_from_content() {
        let id1 = ContextId::from_content("hello world");
        let id2 = ContextId::from_content("hello world");
        let id3 = ContextId::from_content("different content");

        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
        // 16 digest bytes encode to 24 padded base64 characters.
        assert_eq!(id1.as_str().len(), 24);
        // `Context::new` derives its id from the content.
        assert_eq!(Context::new("hello world", ContextDomain::General).id, id1);
    }

    #[test]
    fn test_context_id_new_is_random() {
        assert_ne!(ContextId::new(), ContextId::new());
    }

    #[test]
    fn test_context_age() {
        let ctx = Context::new("Test", ContextDomain::General);
        assert!(ctx.age_seconds() >= 0);
        assert!(ctx.age_hours() >= 0.0);
    }

    #[test]
    fn test_mark_accessed_moves_forward() {
        let mut ctx = Context::new("Test", ContextDomain::General);
        ctx.mark_accessed();
        assert!(ctx.accessed_at >= ctx.created_at);
    }

    #[test]
    fn test_importance_is_clamped() {
        let high = Context::new("x", ContextDomain::General).with_importance(2.5);
        let low = Context::new("x", ContextDomain::General).with_importance(-1.0);
        let mid = Context::new("x", ContextDomain::General).with_importance(0.25);
        assert_eq!(high.metadata.importance, 1.0);
        assert_eq!(low.metadata.importance, 0.0);
        assert_eq!(mid.metadata.importance, 0.25);
    }

    #[test]
    fn test_expiration() {
        let past = Context::new("x", ContextDomain::General)
            .with_expiration(chrono::Utc::now() - chrono::Duration::seconds(60));
        assert!(past.is_expired());

        let future = Context::new("x", ContextDomain::General)
            .with_ttl(std::time::Duration::from_secs(3600));
        assert!(future.expires_at.is_some());
        assert!(!future.is_expired());
    }

    #[test]
    fn test_is_safe() {
        let mut ctx = Context::new("x", ContextDomain::General);
        assert!(ctx.is_safe(), "unscreened content counts as safe");

        ctx.metadata.screening_status = ScreeningStatus::Safe;
        assert!(ctx.is_safe());

        for status in [
            ScreeningStatus::Flagged,
            ScreeningStatus::Blocked,
            ScreeningStatus::Pending,
        ] {
            ctx.metadata.screening_status = status;
            assert!(!ctx.is_safe());
        }
    }

    #[test]
    fn test_metadata_serde_defaults_match_default_impl() {
        let parsed: ContextMetadata = serde_json::from_str("{}").unwrap();
        let built = ContextMetadata::default();
        assert_eq!(parsed.importance, 1.0);
        assert_eq!(parsed.importance, built.importance);
        assert_eq!(parsed.screening_status, ScreeningStatus::Unscreened);
        assert_eq!(parsed.screening_status, built.screening_status);
        assert!(parsed.source.is_empty() && parsed.tags.is_empty());
        assert!(!parsed.verified);
        assert!(parsed.custom.is_empty());
    }

    #[test]
    fn test_enum_serde_names() {
        assert_eq!(
            serde_json::to_string(&ContextDomain::WebSearch).unwrap(),
            "\"web_search\""
        );
        assert_eq!(
            serde_json::to_string(&ContextDomain::Custom("x".to_string())).unwrap(),
            "{\"custom\":\"x\"}"
        );
        assert_eq!(
            serde_json::to_string(&ScreeningStatus::Unscreened).unwrap(),
            "\"unscreened\""
        );
    }

    #[test]
    fn test_context_serde_omits_empty_options() {
        let ctx = Context::new("round trip", ContextDomain::Research);
        let json = serde_json::to_string(&ctx).unwrap();
        assert!(!json.contains("expires_at"));
        assert!(!json.contains("embedding"));

        let back: Context = serde_json::from_str(&json).unwrap();
        assert_eq!(back.id, ctx.id);
        assert_eq!(back.content, ctx.content);
        assert_eq!(back.domain, ctx.domain);
        assert!(back.expires_at.is_none());
        assert!(back.embedding.is_none());
    }

    #[test]
    fn test_context_query_builder() {
        let query = ContextQuery::new()
            .with_text("search term")
            .with_domain(ContextDomain::Code)
            .with_min_importance(0.5)
            .with_limit(20);

        assert_eq!(query.query, Some("search term".to_string()));
        assert_eq!(query.domain_filter, Some(ContextDomain::Code));
        assert_eq!(query.min_importance, Some(0.5));
        assert_eq!(query.limit, 20);
    }

    #[test]
    fn test_context_query_new_defaults() {
        let query = ContextQuery::new();
        assert_eq!(query.limit, 10);
        assert!(query.query.is_none());
        assert!(query.tag_filter.is_none());
        assert!(!query.verified_only);
        assert!(ContextQuery::new().verified_only().verified_only);
    }

    #[test]
    fn test_context_query_tags_accumulate() {
        let query = ContextQuery::new()
            .with_tag("a".to_string())
            .with_tag("b".to_string());
        assert_eq!(
            query.tag_filter,
            Some(vec!["a".to_string(), "b".to_string()])
        );
    }

    #[test]
    fn test_context_query_max_age_hours() {
        assert_eq!(
            ContextQuery::new().with_max_age_hours(2).max_age_seconds,
            Some(7200)
        );
        assert_eq!(
            ContextQuery::new().with_max_age(90).max_age_seconds,
            Some(90)
        );
    }
}