1use std::collections::HashMap;
7use std::sync::Arc;
8
9use serde::{Deserialize, Serialize};
10use serde_json::Value as JsonValue;
11
12use super::security::SecurityValidator;
13use crate::matrixrpc::{ErrorCode, JsonRpcError, JsonRpcId, JsonRpcResponse, ServiceId};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum ContextOperation {
19 Get,
21
22 Set,
24
25 Delete,
27
28 List,
30
31 Exists,
33
34 Clear,
36}
37
38impl Default for ContextOperation {
39 fn default() -> Self {
40 Self::Get
41 }
42}
43
44impl ContextOperation {
45 pub fn as_str(&self) -> &'static str {
47 match self {
48 ContextOperation::Get => "get",
49 ContextOperation::Set => "set",
50 ContextOperation::Delete => "delete",
51 ContextOperation::List => "list",
52 ContextOperation::Exists => "exists",
53 ContextOperation::Clear => "clear",
54 }
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct ContextCallbackRequest {
61 pub request_id: String,
63
64 pub service_id: ServiceId,
66
67 pub token: String,
69
70 #[serde(default)]
72 pub operation: ContextOperation,
73
74 #[serde(default)]
76 pub key: Option<String>,
77
78 #[serde(default)]
80 pub value: Option<JsonValue>,
81
82 #[serde(default)]
84 pub namespace: Option<String>,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct ContextCallbackResult {
90 pub operation: String,
92
93 #[serde(default)]
95 pub key: Option<String>,
96
97 #[serde(default)]
99 pub value: Option<JsonValue>,
100
101 #[serde(default)]
103 pub keys: Vec<String>,
104
105 #[serde(default)]
107 pub exists: Option<bool>,
108
109 pub status: String,
111
112 #[serde(default)]
114 pub metadata: JsonValue,
115}
116
117#[derive(Debug, thiserror::Error)]
119pub enum ContextCallbackError {
120 #[error("Security validation failed: {0}")]
122 SecurityFailed(String),
123
124 #[error("Context key '{0}' not found")]
126 KeyNotFound(String),
127
128 #[error("Context key '{0}' already exists")]
130 KeyExists(String),
131
132 #[error("Invalid context operation: {0}")]
134 InvalidOperation(String),
135
136 #[error("Missing key for context operation")]
138 MissingKey,
139
140 #[error("Missing value for Set operation")]
142 MissingValue,
143
144 #[error("Namespace '{0}' is not accessible")]
146 NamespaceNotAccessible(String),
147
148 #[error("Context is read-only, cannot perform {0} operation")]
150 ReadOnly(String),
151
152 #[error("Internal error: {0}")]
154 Internal(String),
155}
156
157#[derive(Debug, Clone)]
159pub struct ContextNamespaceConfig {
160 pub public: Vec<String>,
162
163 pub service_namespaces: HashMap<ServiceId, Vec<String>>,
165
166 pub readonly: Vec<String>,
168
169 pub max_size: usize,
171}
172
173impl Default for ContextNamespaceConfig {
174 fn default() -> Self {
175 Self {
176 public: vec![
177 "workflow".to_string(), "input".to_string(),
178 "output".to_string(), "variables".to_string(),
179 ],
180 service_namespaces: HashMap::new(),
181 readonly: vec![
182 "input".to_string(), "system".to_string(),
183 ],
184 max_size: 1024,
185 }
186 }
187}
188
189#[derive(Debug, Default)]
191struct ContextStore {
192 namespaces: HashMap<String, HashMap<String, JsonValue>>,
194}
195
196impl ContextStore {
197 fn new() -> Self {
198 Self::default()
199 }
200
201 fn get(&self, namespace: &str, key: &str) -> Option<&JsonValue> {
202 self.namespaces.get(namespace)?.get(key)
203 }
204
205 fn set(&mut self, namespace: &str, key: &str, value: JsonValue) {
206 self.namespaces
207 .entry(namespace.to_string())
208 .or_insert_with(HashMap::new)
209 .insert(key.to_string(), value);
210 }
211
212 fn delete(&mut self, namespace: &str, key: &str) -> Option<JsonValue> {
213 self.namespaces.get_mut(namespace)?.remove(key)
214 }
215
216 fn list(&self, namespace: &str) -> Vec<String> {
217 self.namespaces
218 .get(namespace)
219 .map(|ns| ns.keys().cloned().collect())
220 .unwrap_or_default()
221 }
222
223 fn exists(&self, namespace: &str, key: &str) -> bool {
224 self.namespaces
225 .get(namespace)
226 .map(|ns| ns.contains_key(key))
227 .unwrap_or(false)
228 }
229
230 fn clear(&mut self, namespace: &str) {
231 if let Some(ns) = self.namespaces.get_mut(namespace) {
232 ns.clear();
233 }
234 }
235}
236
237pub struct ContextCallbackHandler {
241 security: Arc<SecurityValidator>,
243
244 store: Arc<tokio::sync::RwLock<ContextStore>>,
246
247 namespace_config: ContextNamespaceConfig,
249}
250
251impl ContextCallbackHandler {
252 pub fn new(security: Arc<SecurityValidator>) -> Self {
254 Self {
255 security,
256 store: Arc::new(tokio::sync::RwLock::new(ContextStore::new())),
257 namespace_config: ContextNamespaceConfig::default(),
258 }
259 }
260
261 pub fn with_namespace_config(mut self, config: ContextNamespaceConfig) -> Self {
263 self.namespace_config = config;
264 self
265 }
266
267 pub async fn initialize_context(&self, namespace: &str, data: HashMap<String, JsonValue>) {
269 let mut store = self.store.write().await;
270 store.namespaces.insert(namespace.to_string(), data);
271 }
272
273 pub async fn handle(&self, request: ContextCallbackRequest) -> Result<ContextCallbackResult, ContextCallbackError> {
275 let validation = self
277 .security
278 .validate(&request.token, &request.service_id, &request.request_id, "context")
279 .await;
280
281 if !validation.is_valid {
282 return Err(ContextCallbackError::SecurityFailed(
283 validation.error.unwrap_or_else(|| "Unknown security error".to_string()),
284 ));
285 }
286
287 let namespace = request.namespace.clone().unwrap_or_else(|| "workflow".to_string());
289
290 if !self.is_namespace_accessible(&namespace, &request.service_id) {
292 return Err(ContextCallbackError::NamespaceNotAccessible(namespace));
293 }
294
295 if self.namespace_config.readonly.contains(&namespace)
297 && matches!(
298 request.operation,
299 ContextOperation::Set | ContextOperation::Delete | ContextOperation::Clear
300 )
301 {
302 return Err(ContextCallbackError::ReadOnly(request.operation.as_str().to_string()));
303 }
304
305 let mut store = self.store.write().await;
306
307 match request.operation {
308 ContextOperation::Get => {
309 let key = request.key.clone().ok_or(ContextCallbackError::MissingKey)?;
310 let value = store
311 .get(&namespace, &key)
312 .cloned()
313 .ok_or_else(|| ContextCallbackError::KeyNotFound(key.clone()))?;
314
315 Ok(ContextCallbackResult {
316 operation: "get".to_string(),
317 key: Some(key),
318 value: Some(value),
319 keys: vec![],
320 exists: None,
321 status: "success".to_string(),
322 metadata: serde_json::json!({
323 "namespace": namespace,
324 "request_id": request.request_id,
325 }),
326 })
327 }
328
329 ContextOperation::Set => {
330 let key = request.key.clone().ok_or(ContextCallbackError::MissingKey)?;
331 let value = request.value.clone().ok_or(ContextCallbackError::MissingValue)?;
332
333 store.set(&namespace, &key, value.clone());
334
335 Ok(ContextCallbackResult {
336 operation: "set".to_string(),
337 key: Some(key),
338 value: Some(value),
339 keys: vec![],
340 exists: None,
341 status: "success".to_string(),
342 metadata: serde_json::json!({
343 "namespace": namespace,
344 "request_id": request.request_id,
345 }),
346 })
347 }
348
349 ContextOperation::Delete => {
350 let key = request.key.clone().ok_or(ContextCallbackError::MissingKey)?;
351 let existed = store.delete(&namespace, &key).is_some();
352
353 Ok(ContextCallbackResult {
354 operation: "delete".to_string(),
355 key: Some(key),
356 value: None,
357 keys: vec![],
358 exists: Some(existed),
359 status: if existed { "success" } else { "not_found" }.to_string(),
360 metadata: serde_json::json!({
361 "namespace": namespace,
362 "request_id": request.request_id,
363 }),
364 })
365 }
366
367 ContextOperation::List => {
368 let keys = store.list(&namespace);
369 let keys_count = keys.len();
370
371 Ok(ContextCallbackResult {
372 operation: "list".to_string(),
373 key: None,
374 value: None,
375 keys,
376 exists: None,
377 status: "success".to_string(),
378 metadata: serde_json::json!({
379 "namespace": namespace,
380 "request_id": request.request_id,
381 "count": keys_count,
382 }),
383 })
384 }
385
386 ContextOperation::Exists => {
387 let key = request.key.clone().ok_or(ContextCallbackError::MissingKey)?;
388 let exists = store.exists(&namespace, &key);
389
390 Ok(ContextCallbackResult {
391 operation: "exists".to_string(),
392 key: Some(key),
393 value: None,
394 keys: vec![],
395 exists: Some(exists),
396 status: "success".to_string(),
397 metadata: serde_json::json!({
398 "namespace": namespace,
399 "request_id": request.request_id,
400 }),
401 })
402 }
403
404 ContextOperation::Clear => {
405 store.clear(&namespace);
406
407 Ok(ContextCallbackResult {
408 operation: "clear".to_string(),
409 key: None,
410 value: None,
411 keys: vec![],
412 exists: None,
413 status: "success".to_string(),
414 metadata: serde_json::json!({
415 "namespace": namespace,
416 "request_id": request.request_id,
417 }),
418 })
419 }
420 }
421 }
422
423 fn is_namespace_accessible(&self, namespace: &str, service_id: &ServiceId) -> bool {
425 if self.namespace_config.public.contains(&namespace.to_string()) {
427 return true;
428 }
429
430 if let Some(namespaces) = self.namespace_config.service_namespaces.get(service_id) {
432 if namespaces.contains(&namespace.to_string()) {
433 return true;
434 }
435 }
436
437 false
438 }
439
440 pub fn create_error_response(&self, error: ContextCallbackError, id: JsonRpcId) -> JsonRpcResponse {
442 let (code, message, data) = match error {
443 ContextCallbackError::SecurityFailed(msg) => (
444 ErrorCode::PERMISSION_DENIED,
445 "Security validation failed".to_string(),
446 Some(serde_json::json!({ "reason": msg })),
447 ),
448 ContextCallbackError::KeyNotFound(key) => (
449 ErrorCode::RESOURCE_NOT_FOUND,
450 format!("Context key '{}' not found", key),
451 None,
452 ),
453 ContextCallbackError::KeyExists(key) => (
454 ErrorCode::RESOURCE_EXISTS,
455 format!("Context key '{}' already exists", key),
456 None,
457 ),
458 ContextCallbackError::InvalidOperation(op) => (
459 ErrorCode::INVALID_PARAMS,
460 format!("Invalid context operation: {}", op),
461 None,
462 ),
463 ContextCallbackError::MissingKey => (
464 ErrorCode::INVALID_PARAMS,
465 "Missing key for context operation".to_string(),
466 None,
467 ),
468 ContextCallbackError::MissingValue => (
469 ErrorCode::INVALID_PARAMS,
470 "Missing value for Set operation".to_string(),
471 None,
472 ),
473 ContextCallbackError::NamespaceNotAccessible(ns) => (
474 ErrorCode::PERMISSION_DENIED,
475 format!("Namespace '{}' is not accessible", ns),
476 None,
477 ),
478 ContextCallbackError::ReadOnly(op) => (
479 ErrorCode::PERMISSION_DENIED,
480 format!("Context is read-only, cannot perform {} operation", op),
481 None,
482 ),
483 ContextCallbackError::Internal(msg) => (
484 ErrorCode::INTERNAL_ERROR,
485 msg,
486 None,
487 ),
488 };
489
490 JsonRpcResponse::error(
491 id,
492 JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)),
493 )
494 }
495
496 pub fn get_available_namespaces(&self, service_id: &ServiceId) -> Vec<String> {
498 let mut namespaces = self.namespace_config.public.clone();
499
500 if let Some(service_ns) = self.namespace_config.service_namespaces.get(service_id) {
501 namespaces.extend(service_ns.clone());
502 }
503
504 namespaces
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a validator and a handler that shares it.
    fn fixture() -> (Arc<SecurityValidator>, ContextCallbackHandler) {
        let security = Arc::new(SecurityValidator::new());
        let handler = ContextCallbackHandler::new(Arc::clone(&security));
        (security, handler)
    }

    /// Builds a request from `test-service` / `req-001` carrying a token
    /// generated for that pair with the `["context"]` list.
    async fn signed_request(
        security: &Arc<SecurityValidator>,
        operation: ContextOperation,
        key: Option<&str>,
        value: Option<JsonValue>,
        namespace: &str,
    ) -> ContextCallbackRequest {
        let service_id = ServiceId::new("test-service");
        let request_id = String::from("req-001");
        let token = security
            .generate_token(service_id.clone(), request_id.clone(), vec!["context".to_string()])
            .await
            .unwrap();

        ContextCallbackRequest {
            request_id,
            service_id,
            token,
            operation,
            key: key.map(str::to_string),
            value,
            namespace: Some(namespace.to_string()),
        }
    }

    #[tokio::test]
    async fn test_context_callback_handler_creation() {
        let (_security, handler) = fixture();

        assert!(!handler.namespace_config.public.is_empty());
    }

    #[tokio::test]
    async fn test_initialize_context() {
        let (_security, handler) = fixture();

        let mut data = HashMap::new();
        data.insert("key1".to_string(), serde_json::json!("value1"));
        data.insert("key2".to_string(), serde_json::json!(42));
        handler.initialize_context("workflow", data).await;

        let store = handler.store.read().await;
        assert_eq!(store.list("workflow").len(), 2);
    }

    #[tokio::test]
    async fn test_context_get() {
        let (security, handler) = fixture();

        let seed = HashMap::from([("test_key".to_string(), serde_json::json!("test_value"))]);
        handler.initialize_context("workflow", seed).await;

        let request = signed_request(
            &security,
            ContextOperation::Get,
            Some("test_key"),
            None,
            "workflow",
        )
        .await;

        let result = handler.handle(request).await.unwrap();
        assert_eq!(result.operation, "get");
        assert_eq!(result.key, Some("test_key".to_string()));
        assert_eq!(result.value, Some(serde_json::json!("test_value")));
    }

    #[tokio::test]
    async fn test_context_set() {
        let (security, handler) = fixture();

        let request = signed_request(
            &security,
            ContextOperation::Set,
            Some("new_key"),
            Some(serde_json::json!("new_value")),
            "workflow",
        )
        .await;

        let result = handler.handle(request).await.unwrap();
        assert_eq!(result.operation, "set");
        assert_eq!(result.status, "success");
    }

    #[tokio::test]
    async fn test_context_list() {
        let (security, handler) = fixture();

        let seed = HashMap::from([
            ("key1".to_string(), serde_json::json!(1)),
            ("key2".to_string(), serde_json::json!(2)),
        ]);
        handler.initialize_context("workflow", seed).await;

        let request =
            signed_request(&security, ContextOperation::List, None, None, "workflow").await;

        let result = handler.handle(request).await.unwrap();
        assert_eq!(result.keys.len(), 2);
    }

    #[tokio::test]
    async fn test_context_exists() {
        let (security, handler) = fixture();

        let seed = HashMap::from([("existing_key".to_string(), serde_json::json!("value"))]);
        handler.initialize_context("workflow", seed).await;

        let request = signed_request(
            &security,
            ContextOperation::Exists,
            Some("existing_key"),
            None,
            "workflow",
        )
        .await;

        let result = handler.handle(request).await.unwrap();
        assert_eq!(result.exists, Some(true));
    }

    #[tokio::test]
    async fn test_context_readonly_namespace() {
        let (security, handler) = fixture();

        // `input` is public but read-only in the default configuration.
        let request = signed_request(
            &security,
            ContextOperation::Set,
            Some("key"),
            Some(serde_json::json!("value")),
            "input",
        )
        .await;

        let outcome = handler.handle(request).await;
        assert!(matches!(outcome, Err(ContextCallbackError::ReadOnly(_))));
    }

    #[test]
    fn test_namespace_accessible() {
        let (_security, handler) = fixture();
        let anyone = ServiceId::new("any");

        assert!(handler.is_namespace_accessible("workflow", &anyone));

        assert!(!handler.is_namespace_accessible("private", &anyone));
    }

    #[test]
    fn test_get_available_namespaces() {
        let (_security, handler) = fixture();

        let available = handler.get_available_namespaces(&ServiceId::new("test"));
        for expected in ["workflow", "input"].iter() {
            assert!(available.contains(&expected.to_string()));
        }
    }
}