Skip to main content

matrixcode_core/matrixrpc/callback/
context.rs

//! Context Callback Handler
//!
//! Handles context callback requests from external services,
//! giving external nodes access to workflow context data.

6use std::collections::HashMap;
7use std::sync::Arc;
8
9use serde::{Deserialize, Serialize};
10use serde_json::Value as JsonValue;
11
12use super::security::SecurityValidator;
13use crate::matrixrpc::{ErrorCode, JsonRpcError, JsonRpcId, JsonRpcResponse, ServiceId};
14
15/// Context operation type
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum ContextOperation {
19    /// Get a value from context
20    Get,
21
22    /// Set a value in context
23    Set,
24
25    /// Delete a value from context
26    Delete,
27
28    /// List all keys
29    List,
30
31    /// Check if key exists
32    Exists,
33
34    /// Clear all context
35    Clear,
36}
37
38impl Default for ContextOperation {
39    fn default() -> Self {
40        Self::Get
41    }
42}
43
44impl ContextOperation {
45    /// Get the string representation
46    pub fn as_str(&self) -> &'static str {
47        match self {
48            ContextOperation::Get => "get",
49            ContextOperation::Set => "set",
50            ContextOperation::Delete => "delete",
51            ContextOperation::List => "list",
52            ContextOperation::Exists => "exists",
53            ContextOperation::Clear => "clear",
54        }
55    }
56}
57
58/// Context callback request
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct ContextCallbackRequest {
61    /// Request ID from original node execution
62    pub request_id: String,
63
64    /// Service ID making the callback
65    pub service_id: ServiceId,
66
67    /// Security token
68    pub token: String,
69
70    /// Operation to perform
71    #[serde(default)]
72    pub operation: ContextOperation,
73
74    /// Key to access
75    #[serde(default)]
76    pub key: Option<String>,
77
78    /// Value to set (for Set operation)
79    #[serde(default)]
80    pub value: Option<JsonValue>,
81
82    /// Namespace for the key
83    #[serde(default)]
84    pub namespace: Option<String>,
85}
86
87/// Context callback result
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct ContextCallbackResult {
90    /// Operation that was performed
91    pub operation: String,
92
93    /// Key that was accessed
94    #[serde(default)]
95    pub key: Option<String>,
96
97    /// Value (for Get operation)
98    #[serde(default)]
99    pub value: Option<JsonValue>,
100
101    /// Keys (for List operation)
102    #[serde(default)]
103    pub keys: Vec<String>,
104
105    /// Whether key exists (for Exists operation)
106    #[serde(default)]
107    pub exists: Option<bool>,
108
109    /// Status message
110    pub status: String,
111
112    /// Additional metadata
113    #[serde(default)]
114    pub metadata: JsonValue,
115}
116
117/// Context callback error
118#[derive(Debug, thiserror::Error)]
119pub enum ContextCallbackError {
120    /// Security validation failed
121    #[error("Security validation failed: {0}")]
122    SecurityFailed(String),
123
124    /// Key not found
125    #[error("Context key '{0}' not found")]
126    KeyNotFound(String),
127
128    /// Key already exists
129    #[error("Context key '{0}' already exists")]
130    KeyExists(String),
131
132    /// Invalid operation
133    #[error("Invalid context operation: {0}")]
134    InvalidOperation(String),
135
136    /// Missing key
137    #[error("Missing key for context operation")]
138    MissingKey,
139
140    /// Missing value
141    #[error("Missing value for Set operation")]
142    MissingValue,
143
144    /// Namespace not accessible
145    #[error("Namespace '{0}' is not accessible")]
146    NamespaceNotAccessible(String),
147
148    /// Read-only context
149    #[error("Context is read-only, cannot perform {0} operation")]
150    ReadOnly(String),
151
152    /// Internal error
153    #[error("Internal error: {0}")]
154    Internal(String),
155}
156
157/// Context namespace configuration
158#[derive(Debug, Clone)]
159pub struct ContextNamespaceConfig {
160    /// Public namespaces (accessible to all)
161    pub public: Vec<String>,
162
163    /// Service-specific namespaces
164    pub service_namespaces: HashMap<ServiceId, Vec<String>>,
165
166    /// Read-only namespaces
167    pub readonly: Vec<String>,
168
169    /// Maximum context size per namespace
170    pub max_size: usize,
171}
172
173impl Default for ContextNamespaceConfig {
174    fn default() -> Self {
175        Self {
176            public: vec![
177                "workflow".to_string(), "input".to_string(),
178                "output".to_string(), "variables".to_string(),
179            ],
180            service_namespaces: HashMap::new(),
181            readonly: vec![
182                "input".to_string(), "system".to_string(),
183            ],
184            max_size: 1024,
185        }
186    }
187}
188
189/// Context store
190#[derive(Debug, Default)]
191struct ContextStore {
192    /// Data store by namespace
193    namespaces: HashMap<String, HashMap<String, JsonValue>>,
194}
195
196impl ContextStore {
197    fn new() -> Self {
198        Self::default()
199    }
200
201    fn get(&self, namespace: &str, key: &str) -> Option<&JsonValue> {
202        self.namespaces.get(namespace)?.get(key)
203    }
204
205    fn set(&mut self, namespace: &str, key: &str, value: JsonValue) {
206        self.namespaces
207            .entry(namespace.to_string())
208            .or_insert_with(HashMap::new)
209            .insert(key.to_string(), value);
210    }
211
212    fn delete(&mut self, namespace: &str, key: &str) -> Option<JsonValue> {
213        self.namespaces.get_mut(namespace)?.remove(key)
214    }
215
216    fn list(&self, namespace: &str) -> Vec<String> {
217        self.namespaces
218            .get(namespace)
219            .map(|ns| ns.keys().cloned().collect())
220            .unwrap_or_default()
221    }
222
223    fn exists(&self, namespace: &str, key: &str) -> bool {
224        self.namespaces
225            .get(namespace)
226            .map(|ns| ns.contains_key(key))
227            .unwrap_or(false)
228    }
229
230    fn clear(&mut self, namespace: &str) {
231        if let Some(ns) = self.namespaces.get_mut(namespace) {
232            ns.clear();
233        }
234    }
235}
236
237/// Context Callback Handler
238///
239/// Handles context callback requests from external extension services.
240pub struct ContextCallbackHandler {
241    /// Security validator
242    security: Arc<SecurityValidator>,
243
244    /// Context store
245    store: Arc<tokio::sync::RwLock<ContextStore>>,
246
247    /// Namespace configuration
248    namespace_config: ContextNamespaceConfig,
249}
250
251impl ContextCallbackHandler {
252    /// Create a new context callback handler
253    pub fn new(security: Arc<SecurityValidator>) -> Self {
254        Self {
255            security,
256            store: Arc::new(tokio::sync::RwLock::new(ContextStore::new())),
257            namespace_config: ContextNamespaceConfig::default(),
258        }
259    }
260
261    /// Set namespace configuration
262    pub fn with_namespace_config(mut self, config: ContextNamespaceConfig) -> Self {
263        self.namespace_config = config;
264        self
265    }
266
267    /// Initialize with existing context
268    pub async fn initialize_context(&self, namespace: &str, data: HashMap<String, JsonValue>) {
269        let mut store = self.store.write().await;
270        store.namespaces.insert(namespace.to_string(), data);
271    }
272
273    /// Handle a context callback request
274    pub async fn handle(&self, request: ContextCallbackRequest) -> Result<ContextCallbackResult, ContextCallbackError> {
275        // Validate security
276        let validation = self
277            .security
278            .validate(&request.token, &request.service_id, &request.request_id, "context")
279            .await;
280
281        if !validation.is_valid {
282            return Err(ContextCallbackError::SecurityFailed(
283                validation.error.unwrap_or_else(|| "Unknown security error".to_string()),
284            ));
285        }
286
287        // Get namespace (default to "workflow")
288        let namespace = request.namespace.clone().unwrap_or_else(|| "workflow".to_string());
289
290        // Check namespace accessibility
291        if !self.is_namespace_accessible(&namespace, &request.service_id) {
292            return Err(ContextCallbackError::NamespaceNotAccessible(namespace));
293        }
294
295        // Check if read-only namespace for write operations
296        if self.namespace_config.readonly.contains(&namespace)
297            && matches!(
298                request.operation,
299                ContextOperation::Set | ContextOperation::Delete | ContextOperation::Clear
300            )
301        {
302            return Err(ContextCallbackError::ReadOnly(request.operation.as_str().to_string()));
303        }
304
305        let mut store = self.store.write().await;
306
307        match request.operation {
308            ContextOperation::Get => {
309                let key = request.key.clone().ok_or(ContextCallbackError::MissingKey)?;
310                let value = store
311                    .get(&namespace, &key)
312                    .cloned()
313                    .ok_or_else(|| ContextCallbackError::KeyNotFound(key.clone()))?;
314
315                Ok(ContextCallbackResult {
316                    operation: "get".to_string(),
317                    key: Some(key),
318                    value: Some(value),
319                    keys: vec![],
320                    exists: None,
321                    status: "success".to_string(),
322                    metadata: serde_json::json!({
323                        "namespace": namespace,
324                        "request_id": request.request_id,
325                    }),
326                })
327            }
328
329            ContextOperation::Set => {
330                let key = request.key.clone().ok_or(ContextCallbackError::MissingKey)?;
331                let value = request.value.clone().ok_or(ContextCallbackError::MissingValue)?;
332
333                store.set(&namespace, &key, value.clone());
334
335                Ok(ContextCallbackResult {
336                    operation: "set".to_string(),
337                    key: Some(key),
338                    value: Some(value),
339                    keys: vec![],
340                    exists: None,
341                    status: "success".to_string(),
342                    metadata: serde_json::json!({
343                        "namespace": namespace,
344                        "request_id": request.request_id,
345                    }),
346                })
347            }
348
349            ContextOperation::Delete => {
350                let key = request.key.clone().ok_or(ContextCallbackError::MissingKey)?;
351                let existed = store.delete(&namespace, &key).is_some();
352
353                Ok(ContextCallbackResult {
354                    operation: "delete".to_string(),
355                    key: Some(key),
356                    value: None,
357                    keys: vec![],
358                    exists: Some(existed),
359                    status: if existed { "success" } else { "not_found" }.to_string(),
360                    metadata: serde_json::json!({
361                        "namespace": namespace,
362                        "request_id": request.request_id,
363                    }),
364                })
365            }
366
367            ContextOperation::List => {
368                let keys = store.list(&namespace);
369                let keys_count = keys.len();
370
371                Ok(ContextCallbackResult {
372                    operation: "list".to_string(),
373                    key: None,
374                    value: None,
375                    keys,
376                    exists: None,
377                    status: "success".to_string(),
378                    metadata: serde_json::json!({
379                        "namespace": namespace,
380                        "request_id": request.request_id,
381                        "count": keys_count,
382                    }),
383                })
384            }
385
386            ContextOperation::Exists => {
387                let key = request.key.clone().ok_or(ContextCallbackError::MissingKey)?;
388                let exists = store.exists(&namespace, &key);
389
390                Ok(ContextCallbackResult {
391                    operation: "exists".to_string(),
392                    key: Some(key),
393                    value: None,
394                    keys: vec![],
395                    exists: Some(exists),
396                    status: "success".to_string(),
397                    metadata: serde_json::json!({
398                        "namespace": namespace,
399                        "request_id": request.request_id,
400                    }),
401                })
402            }
403
404            ContextOperation::Clear => {
405                store.clear(&namespace);
406
407                Ok(ContextCallbackResult {
408                    operation: "clear".to_string(),
409                    key: None,
410                    value: None,
411                    keys: vec![],
412                    exists: None,
413                    status: "success".to_string(),
414                    metadata: serde_json::json!({
415                        "namespace": namespace,
416                        "request_id": request.request_id,
417                    }),
418                })
419            }
420        }
421    }
422
423    /// Check if namespace is accessible for a service
424    fn is_namespace_accessible(&self, namespace: &str, service_id: &ServiceId) -> bool {
425        // Public namespaces are accessible to all
426        if self.namespace_config.public.contains(&namespace.to_string()) {
427            return true;
428        }
429
430        // Check service-specific namespaces
431        if let Some(namespaces) = self.namespace_config.service_namespaces.get(service_id) {
432            if namespaces.contains(&namespace.to_string()) {
433                return true;
434            }
435        }
436
437        false
438    }
439
440    /// Create a JSON-RPC error response for context callback failures
441    pub fn create_error_response(&self, error: ContextCallbackError, id: JsonRpcId) -> JsonRpcResponse {
442        let (code, message, data) = match error {
443            ContextCallbackError::SecurityFailed(msg) => (
444                ErrorCode::PERMISSION_DENIED,
445                "Security validation failed".to_string(),
446                Some(serde_json::json!({ "reason": msg })),
447            ),
448            ContextCallbackError::KeyNotFound(key) => (
449                ErrorCode::RESOURCE_NOT_FOUND,
450                format!("Context key '{}' not found", key),
451                None,
452            ),
453            ContextCallbackError::KeyExists(key) => (
454                ErrorCode::RESOURCE_EXISTS,
455                format!("Context key '{}' already exists", key),
456                None,
457            ),
458            ContextCallbackError::InvalidOperation(op) => (
459                ErrorCode::INVALID_PARAMS,
460                format!("Invalid context operation: {}", op),
461                None,
462            ),
463            ContextCallbackError::MissingKey => (
464                ErrorCode::INVALID_PARAMS,
465                "Missing key for context operation".to_string(),
466                None,
467            ),
468            ContextCallbackError::MissingValue => (
469                ErrorCode::INVALID_PARAMS,
470                "Missing value for Set operation".to_string(),
471                None,
472            ),
473            ContextCallbackError::NamespaceNotAccessible(ns) => (
474                ErrorCode::PERMISSION_DENIED,
475                format!("Namespace '{}' is not accessible", ns),
476                None,
477            ),
478            ContextCallbackError::ReadOnly(op) => (
479                ErrorCode::PERMISSION_DENIED,
480                format!("Context is read-only, cannot perform {} operation", op),
481                None,
482            ),
483            ContextCallbackError::Internal(msg) => (
484                ErrorCode::INTERNAL_ERROR,
485                msg,
486                None,
487            ),
488        };
489
490        JsonRpcResponse::error(
491            id,
492            JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)),
493        )
494    }
495
496    /// Get all available namespaces for a service
497    pub fn get_available_namespaces(&self, service_id: &ServiceId) -> Vec<String> {
498        let mut namespaces = self.namespace_config.public.clone();
499
500        if let Some(service_ns) = self.namespace_config.service_namespaces.get(service_id) {
501            namespaces.extend(service_ns.clone());
502        }
503
504        namespaces
505    }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh validator plus a handler that shares it.
    fn setup() -> (Arc<SecurityValidator>, ContextCallbackHandler) {
        let security = Arc::new(SecurityValidator::new());
        let handler = ContextCallbackHandler::new(Arc::clone(&security));
        (security, handler)
    }

    /// Build a request from `test-service` for `req-001`.
    ///
    /// The request is signed with a freshly issued token scoped to `"context"`.
    async fn signed_request(
        security: &Arc<SecurityValidator>,
        operation: ContextOperation,
        key: Option<&str>,
        value: Option<JsonValue>,
        namespace: &str,
    ) -> ContextCallbackRequest {
        let service_id = ServiceId::new("test-service");
        let request_id = "req-001".to_string();
        let token = security
            .generate_token(service_id.clone(), request_id.clone(), vec!["context".to_string()])
            .await
            .unwrap();

        ContextCallbackRequest {
            request_id,
            service_id,
            token,
            operation,
            key: key.map(String::from),
            value,
            namespace: Some(namespace.to_string()),
        }
    }

    #[tokio::test]
    async fn test_context_callback_handler_creation() {
        let (_security, handler) = setup();

        assert!(!handler.namespace_config.public.is_empty());
    }

    #[tokio::test]
    async fn test_initialize_context() {
        let (_security, handler) = setup();

        handler
            .initialize_context(
                "workflow",
                HashMap::from([
                    ("key1".to_string(), serde_json::json!("value1")),
                    ("key2".to_string(), serde_json::json!(42)),
                ]),
            )
            .await;

        // Inspect the store directly instead of going through `handle`.
        let listed = handler.store.read().await.list("workflow");
        assert_eq!(listed.len(), 2);
    }

    #[tokio::test]
    async fn test_context_get() {
        let (security, handler) = setup();
        handler
            .initialize_context(
                "workflow",
                HashMap::from([("test_key".to_string(), serde_json::json!("test_value"))]),
            )
            .await;

        let request = signed_request(
            &security,
            ContextOperation::Get,
            Some("test_key"),
            None,
            "workflow",
        )
        .await;
        let outcome = handler.handle(request).await.unwrap();

        assert_eq!(outcome.operation, "get");
        assert_eq!(outcome.key.as_deref(), Some("test_key"));
        assert_eq!(outcome.value, Some(serde_json::json!("test_value")));
    }

    #[tokio::test]
    async fn test_context_set() {
        let (security, handler) = setup();

        let request = signed_request(
            &security,
            ContextOperation::Set,
            Some("new_key"),
            Some(serde_json::json!("new_value")),
            "workflow",
        )
        .await;
        let outcome = handler.handle(request).await.unwrap();

        assert_eq!(outcome.operation, "set");
        assert_eq!(outcome.status, "success");
    }

    #[tokio::test]
    async fn test_context_list() {
        let (security, handler) = setup();
        handler
            .initialize_context(
                "workflow",
                HashMap::from([
                    ("key1".to_string(), serde_json::json!(1)),
                    ("key2".to_string(), serde_json::json!(2)),
                ]),
            )
            .await;

        let request =
            signed_request(&security, ContextOperation::List, None, None, "workflow").await;
        let outcome = handler.handle(request).await.unwrap();

        assert_eq!(outcome.keys.len(), 2);
    }

    #[tokio::test]
    async fn test_context_exists() {
        let (security, handler) = setup();
        handler
            .initialize_context(
                "workflow",
                HashMap::from([("existing_key".to_string(), serde_json::json!("value"))]),
            )
            .await;

        let request = signed_request(
            &security,
            ContextOperation::Exists,
            Some("existing_key"),
            None,
            "workflow",
        )
        .await;
        let outcome = handler.handle(request).await.unwrap();

        assert_eq!(outcome.exists, Some(true));
    }

    #[tokio::test]
    async fn test_context_readonly_namespace() {
        let (security, handler) = setup();

        // "input" is read-only in the default configuration.
        let request = signed_request(
            &security,
            ContextOperation::Set,
            Some("key"),
            Some(serde_json::json!("value")),
            "input",
        )
        .await;
        let outcome = handler.handle(request).await;

        assert!(matches!(outcome, Err(ContextCallbackError::ReadOnly(_))));
    }

    #[test]
    fn test_namespace_accessible() {
        let (_security, handler) = setup();
        let anyone = ServiceId::new("any");

        // Public namespace.
        assert!(handler.is_namespace_accessible("workflow", &anyone));

        // Neither public nor granted to this service.
        assert!(!handler.is_namespace_accessible("private", &anyone));
    }

    #[test]
    fn test_get_available_namespaces() {
        let (_security, handler) = setup();

        let available = handler.get_available_namespaces(&ServiceId::new("test"));

        assert!(available.contains(&"workflow".to_string()));
        assert!(available.contains(&"input".to_string()));
    }
}