1use base64::Engine;
7use chrono::{DateTime, Duration, Utc};
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use uuid::Uuid;
11
/// Identifier for a [`Context`], stored as a plain string.
///
/// The inner string is public, so any value can be wrapped. The constructors
/// on this type produce a random UUID, a value derived from content, or wrap a
/// caller-supplied string unchanged.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct ContextId(pub String);
15
16impl ContextId {
17 pub fn new() -> Self {
19 Self(Uuid::new_v4().to_string())
20 }
21
22 pub fn from_content(content: &str) -> Self {
24 let mut hasher = Sha256::new();
25 hasher.update(content.as_bytes());
26 let hash = hasher.finalize();
27 Self(base64::engine::general_purpose::STANDARD.encode(&hash[..16]))
28 }
29
30 pub fn from_string(s: String) -> Self {
32 Self(s)
33 }
34
35 pub fn as_str(&self) -> &str {
36 &self.0
37 }
38}
39
40impl Default for ContextId {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl std::fmt::Display for ContextId {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 write!(f, "{}", self.0)
49 }
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
54#[serde(rename_all = "snake_case")]
55pub enum ContextDomain {
56 General,
58 Code,
60 Documentation,
62 Conversation,
64 Filesystem,
66 WebSearch,
68 Dataset,
70 Research,
72 Custom(String),
74}
75
76#[allow(clippy::derivable_impls)]
77impl Default for ContextDomain {
78 fn default() -> Self {
79 Self::General
80 }
81}
82
/// Descriptive metadata stored alongside a [`Context`].
///
/// Every field has a serde default, so all of them may be omitted from
/// serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMetadata {
    /// Free-form source string; empty when absent from the input.
    #[serde(default)]
    pub source: String,

    /// Tags attached to the context; empty when absent from the input.
    #[serde(default)]
    pub tags: Vec<String>,

    /// Importance score; falls back to `default_importance()` (1.0) when absent.
    ///
    /// `Context::with_importance` clamps to `0.0..=1.0`, but this field is
    /// public and is not range-checked when set directly or deserialized.
    #[serde(default = "default_importance")]
    pub importance: f32,

    /// Verification flag; `false` when absent from the input.
    #[serde(default)]
    pub verified: bool,

    /// Screening state; `ScreeningStatus::Unscreened` when absent from the input.
    #[serde(default)]
    pub screening_status: ScreeningStatus,

    /// Arbitrary extra key/value pairs holding JSON values; empty when absent.
    #[serde(default)]
    pub custom: std::collections::HashMap<String, serde_json::Value>,
}
110
/// Importance used when `ContextMetadata::importance` is missing from
/// deserialized input.
fn default_importance() -> f32 {
    const DEFAULT_IMPORTANCE: f32 = 1.0;
    DEFAULT_IMPORTANCE
}
114
115impl Default for ContextMetadata {
116 fn default() -> Self {
117 Self {
118 source: String::new(),
119 tags: Vec::new(),
120 importance: 1.0,
121 verified: false,
122 screening_status: ScreeningStatus::Unscreened,
123 custom: std::collections::HashMap::new(),
124 }
125 }
126}
127
/// Screening state recorded in `ContextMetadata::screening_status`.
///
/// Serialized in `snake_case`. `Context::is_safe` returns `true` only for
/// `Safe` and `Unscreened`; the remaining variants make it return `false`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ScreeningStatus {
    /// The default state.
    #[default]
    Unscreened,
    Safe,
    Flagged,
    Blocked,
    Pending,
}
144
/// A piece of content together with its identifier, domain, timestamps,
/// metadata and an optional embedding vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Context {
    /// Identifier; `Context::new` derives it from the content.
    pub id: ContextId,

    /// The stored text.
    pub content: String,

    /// Category tag for the content.
    pub domain: ContextDomain,

    /// Creation time (UTC); `Context::new` sets it to the current time.
    pub created_at: DateTime<Utc>,

    /// Last-access time (UTC); updated by `Context::mark_accessed`.
    pub accessed_at: DateTime<Utc>,

    /// Expiration time, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<DateTime<Utc>>,

    /// Descriptive metadata.
    pub metadata: ContextMetadata,

    /// Optional vector of `f32` values; omitted from serialized output when `None`.
    // NOTE(review): presumably an embedding of `content`; nothing visible here
    // computes it or constrains its length.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding: Option<Vec<f32>>,
}
179
180impl Context {
181 pub fn new(content: impl Into<String>, domain: ContextDomain) -> Self {
183 let content = content.into();
184 let now = Utc::now();
185 Self {
186 id: ContextId::from_content(&content),
187 content,
188 domain,
189 created_at: now,
190 accessed_at: now,
191 expires_at: None,
192 metadata: ContextMetadata::default(),
193 embedding: None,
194 }
195 }
196
197 pub fn with_id(mut self, id: ContextId) -> Self {
199 self.id = id;
200 self
201 }
202
203 pub fn with_metadata(mut self, metadata: ContextMetadata) -> Self {
205 self.metadata = metadata;
206 self
207 }
208
209 pub fn with_source(mut self, source: impl Into<String>) -> Self {
211 self.metadata.source = source.into();
212 self
213 }
214
215 pub fn with_importance(mut self, importance: f32) -> Self {
217 self.metadata.importance = importance.clamp(0.0, 1.0);
218 self
219 }
220
221 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
223 self.metadata.tags = tags;
224 self
225 }
226
227 pub fn with_expiration(mut self, expires_at: DateTime<Utc>) -> Self {
229 self.expires_at = Some(expires_at);
230 self
231 }
232
233 pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
235 self.embedding = Some(embedding);
236 self
237 }
238
239 pub fn with_ttl(mut self, ttl: std::time::Duration) -> Self {
241 self.expires_at = Some(Utc::now() + Duration::from_std(ttl).unwrap_or(Duration::hours(24)));
242 self
243 }
244
245 pub fn is_expired(&self) -> bool {
247 self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
248 }
249
250 pub fn age_seconds(&self) -> i64 {
252 (Utc::now() - self.created_at).num_seconds()
253 }
254
255 pub fn age_hours(&self) -> f64 {
257 self.age_seconds() as f64 / 3600.0
258 }
259
260 pub fn mark_accessed(&mut self) {
262 self.accessed_at = Utc::now();
263 }
264
265 pub fn is_safe(&self) -> bool {
267 matches!(
268 self.metadata.screening_status,
269 ScreeningStatus::Safe | ScreeningStatus::Unscreened
270 )
271 }
272}
273
/// Criteria for looking up contexts.
///
/// This type only carries the criteria; how each one is applied is decided by
/// the code that consumes the query. Note that the derived `Default` leaves
/// `limit` at `0`, whereas `ContextQuery::new` starts from `10`.
#[derive(Debug, Clone, Default)]
pub struct ContextQuery {
    /// Free-text query, if any.
    pub query: Option<String>,
    /// Restrict to a single domain, if set.
    pub domain_filter: Option<ContextDomain>,
    /// Tags to filter on, if set.
    // NOTE(review): whether a match needs any or all of the tags is not
    // determined by this type. Confirm against the consumer.
    pub tag_filter: Option<Vec<String>>,
    /// Source string to filter on, if set.
    pub source_filter: Option<String>,
    /// Minimum importance, if set. Stored as given, not clamped.
    pub min_importance: Option<f32>,
    /// Maximum age in seconds, if set.
    pub max_age_seconds: Option<i64>,
    /// When `true`, only verified contexts are requested.
    pub verified_only: bool,
    /// Maximum number of results requested.
    pub limit: usize,
}
294
295impl ContextQuery {
296 pub fn new() -> Self {
297 Self {
298 limit: 10,
299 ..Default::default()
300 }
301 }
302
303 pub fn with_text(mut self, query: impl Into<String>) -> Self {
304 self.query = Some(query.into());
305 self
306 }
307
308 pub fn with_domain(mut self, domain: ContextDomain) -> Self {
309 self.domain_filter = Some(domain);
310 self
311 }
312
313 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
314 self.tag_filter = Some(tags);
315 self
316 }
317
318 pub fn with_min_importance(mut self, importance: f32) -> Self {
319 self.min_importance = Some(importance);
320 self
321 }
322
323 pub fn with_max_age(mut self, seconds: i64) -> Self {
324 self.max_age_seconds = Some(seconds);
325 self
326 }
327
328 pub fn with_max_age_hours(mut self, hours: i64) -> Self {
329 self.max_age_seconds = Some(hours * 3600);
330 self
331 }
332
333 pub fn with_tag(mut self, tag: String) -> Self {
334 if self.tag_filter.is_none() {
335 self.tag_filter = Some(Vec::new());
336 }
337 self.tag_filter.as_mut().unwrap().push(tag);
338 self
339 }
340
341 pub fn verified_only(mut self) -> Self {
342 self.verified_only = true;
343 self
344 }
345
346 pub fn with_limit(mut self, limit: usize) -> Self {
347 self.limit = limit;
348 self
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// A new context keeps its content and domain and starts unexpired.
    #[test]
    fn test_context_creation() {
        let ctx = Context::new("Test content", ContextDomain::Code);
        assert_eq!(ctx.content, "Test content");
        assert_eq!(ctx.domain, ContextDomain::Code);
        assert_eq!(ctx.id, ContextId::from_content("Test content"));
        assert_eq!(ctx.created_at, ctx.accessed_at);
        assert!(ctx.expires_at.is_none());
        assert!(ctx.embedding.is_none());
        assert!(!ctx.is_expired());
        assert!(ctx.is_safe());
    }

    /// Content-derived identifiers are deterministic and content-sensitive.
    #[test]
    fn test_context_id_from_content() {
        let id1 = ContextId::from_content("hello world");
        let id2 = ContextId::from_content("hello world");
        let id3 = ContextId::from_content("different content");

        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
        // 16 bytes in padded standard base64 occupy 24 characters.
        assert_eq!(id1.as_str().len(), 24);
        assert_eq!(id1.to_string(), id1.as_str());
    }

    /// Age accessors are non-negative for a context created just now.
    #[test]
    fn test_context_age() {
        let ctx = Context::new("Test", ContextDomain::General);
        assert!(ctx.age_seconds() >= 0);
        assert!(ctx.age_hours() >= 0.0);
    }

    /// Builder methods store their values; importance is clamped to 1.0.
    #[test]
    fn test_context_builders() {
        let ctx = Context::new("Test", ContextDomain::General)
            .with_id(ContextId::from_string("custom-id".to_string()))
            .with_source("unit-test")
            .with_importance(2.5)
            .with_tags(vec!["a".to_string(), "b".to_string()])
            .with_embedding(vec![0.25, 0.5]);

        assert_eq!(ctx.id.as_str(), "custom-id");
        assert_eq!(ctx.metadata.source, "unit-test");
        assert_eq!(ctx.metadata.importance, 1.0);
        assert_eq!(ctx.metadata.tags, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(ctx.embedding, Some(vec![0.25, 0.5]));

        let low = Context::new("Test", ContextDomain::General).with_importance(-3.0);
        assert_eq!(low.metadata.importance, 0.0);
    }

    /// A past expiration is expired; a future TTL is not.
    #[test]
    fn test_context_expiration() {
        let past = chrono::Utc::now() - chrono::Duration::seconds(60);
        let expired = Context::new("Test", ContextDomain::General).with_expiration(past);
        assert!(expired.is_expired());

        let fresh = Context::new("Test", ContextDomain::General)
            .with_ttl(std::time::Duration::from_secs(3600));
        assert!(fresh.expires_at.is_some());
        assert!(!fresh.is_expired());
    }

    /// Only `Safe` and `Unscreened` count as safe.
    #[test]
    fn test_context_is_safe() {
        let cases = vec![
            (ScreeningStatus::Safe, true),
            (ScreeningStatus::Unscreened, true),
            (ScreeningStatus::Flagged, false),
            (ScreeningStatus::Blocked, false),
            (ScreeningStatus::Pending, false),
        ];
        let mut ctx = Context::new("Test", ContextDomain::General);
        for (status, expected) in cases {
            ctx.metadata.screening_status = status;
            assert_eq!(ctx.is_safe(), expected);
        }
    }

    /// Query builder methods populate the matching fields.
    #[test]
    fn test_context_query_builder() {
        let query = ContextQuery::new()
            .with_text("search term")
            .with_domain(ContextDomain::Code)
            .with_min_importance(0.5)
            .with_limit(20);

        assert_eq!(query.query, Some("search term".to_string()));
        assert_eq!(query.domain_filter, Some(ContextDomain::Code));
        assert_eq!(query.min_importance, Some(0.5));
        assert_eq!(query.limit, 20);
    }

    /// `new` starts at a limit of 10; tags accumulate; hours become seconds.
    #[test]
    fn test_context_query_defaults_and_tags() {
        let base = ContextQuery::new();
        assert_eq!(base.limit, 10);
        assert!(base.tag_filter.is_none());
        assert!(!base.verified_only);

        let query = ContextQuery::new()
            .with_tag("a".to_string())
            .with_tag("b".to_string())
            .with_max_age_hours(2)
            .verified_only();

        assert_eq!(
            query.tag_filter,
            Some(vec!["a".to_string(), "b".to_string()])
        );
        assert_eq!(query.max_age_seconds, Some(7200));
        assert!(query.verified_only);
    }
}