Skip to main content

oxirs_arq/
websocket_streaming.rs

1//! WebSocket Streaming for SPARQL Query Results
2//!
3//! Provides real-time streaming of SPARQL query results over WebSocket connections.
4//! Supports incremental result delivery, query cancellation, and backpressure handling.
5
6use crate::algebra::Variable;
7use anyhow::{anyhow, Result};
8use scirs2_core::metrics::{Counter, Gauge, Timer};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::{mpsc, RwLock};
14
/// Tunable limits and timings for the WebSocket streaming layer.
///
/// A [`WebSocketManager`] clones this value into every session it creates.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Largest message size, in bytes.
    pub max_message_size: usize,
    /// Capacity of the bounded channel created for each query's results.
    pub buffer_size: usize,
    /// Interval between keepalive pings.
    pub ping_interval: Duration,
    /// Age after which a session is no longer reported as healthy.
    pub connection_timeout: Duration,
    /// Cap on concurrent sessions per manager; the same value also caps the
    /// number of concurrently registered queries per session.
    pub max_connections: usize,
    /// Whether result compression is enabled.
    pub enable_compression: bool,
    /// Number of solutions packed into a single result batch.
    pub batch_size: usize,
}

impl Default for WebSocketConfig {
    /// Defaults: 16 MiB messages, a 10 000-slot buffer, 30 s pings, a 300 s
    /// timeout, 1000 connections, compression on, 100 solutions per batch.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        WebSocketConfig {
            max_message_size: 16 * MIB,
            buffer_size: 10_000,
            ping_interval: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(5 * 60),
            max_connections: 1_000,
            enable_compression: true,
            batch_size: 100,
        }
    }
}
47
48/// WebSocket message types
49#[derive(Debug, Clone, Serialize, Deserialize)]
50#[serde(tag = "type")]
51pub enum WebSocketMessage {
52    /// Query request
53    Query {
54        id: String,
55        sparql: String,
56        bindings: Option<HashMap<String, String>>,
57    },
58    /// Query result batch
59    ResultBatch {
60        id: String,
61        variables: Vec<String>,
62        solutions: Vec<HashMap<String, String>>,
63        more: bool,
64    },
65    /// Query completion
66    QueryComplete { id: String, total_results: usize },
67    /// Query error
68    QueryError { id: String, error: String },
69    /// Query cancellation request
70    CancelQuery { id: String },
71    /// Query cancelled confirmation
72    QueryCancelled { id: String },
73    /// Server ping
74    Ping,
75    /// Client pong
76    Pong,
77    /// Connection statistics
78    Stats { stats: ConnectionStats },
79}
80
81/// WebSocket streaming session
82pub struct WebSocketSession {
83    /// Session ID
84    id: String,
85    /// Configuration
86    config: WebSocketConfig,
87    /// Active queries
88    active_queries: Arc<RwLock<HashMap<String, QuerySession>>>,
89    /// Metrics
90    metrics: Arc<SessionMetrics>,
91    /// Connection start time
92    start_time: Instant,
93}
94
95impl WebSocketSession {
96    /// Create a new WebSocket session
97    pub fn new(id: String, config: WebSocketConfig) -> Self {
98        Self {
99            id,
100            config,
101            active_queries: Arc::new(RwLock::new(HashMap::new())),
102            metrics: Arc::new(SessionMetrics::new()),
103            start_time: Instant::now(),
104        }
105    }
106
107    /// Start a query execution
108    pub async fn start_query(
109        &self,
110        query_id: String,
111        sparql: String,
112    ) -> Result<mpsc::Receiver<WebSocketMessage>> {
113        let (tx, rx) = mpsc::channel(self.config.buffer_size);
114
115        // Create query session
116        let session = QuerySession {
117            id: query_id.clone(),
118            sparql: sparql.clone(),
119            start_time: Instant::now(),
120            results_sent: 0,
121            cancelled: false,
122            sender: tx.clone(),
123        };
124
125        // Register query
126        {
127            let mut queries = self.active_queries.write().await;
128            if queries.len() >= self.config.max_connections {
129                return Err(anyhow!("Maximum concurrent queries reached"));
130            }
131            queries.insert(query_id.clone(), session);
132        }
133
134        self.metrics.active_queries.add(1.0);
135        self.metrics.total_queries.inc();
136
137        Ok(rx)
138    }
139
140    /// Stream results for a query
141    /// Note: Solution is `Vec<Binding>`, where `Binding` is `HashMap<Variable, Term>`
142    /// For streaming, we expect bindings (individual solutions), not Solution (`Vec<Binding>`)
143    pub async fn stream_results(
144        &self,
145        query_id: &str,
146        variables: Vec<Variable>,
147        bindings: Vec<crate::algebra::Binding>,
148    ) -> Result<()> {
149        let query_session = {
150            let queries = self.active_queries.read().await;
151            queries
152                .get(query_id)
153                .ok_or_else(|| anyhow!("Query not found: {}", query_id))?
154                .clone()
155        };
156
157        if query_session.cancelled {
158            return Ok(());
159        }
160
161        // Convert variables to strings
162        let var_names: Vec<String> = variables.iter().map(|v| v.to_string()).collect();
163
164        // Stream bindings in batches
165        for batch in bindings.chunks(self.config.batch_size) {
166            if query_session.is_cancelled() {
167                break;
168            }
169
170            // Convert bindings to string maps
171            let solution_maps: Vec<HashMap<String, String>> = batch
172                .iter()
173                .map(|binding| {
174                    binding
175                        .iter()
176                        .map(|(var, term)| (var.to_string(), format!("{:?}", term)))
177                        .collect()
178                })
179                .collect();
180
181            let message = WebSocketMessage::ResultBatch {
182                id: query_id.to_string(),
183                variables: var_names.clone(),
184                solutions: solution_maps,
185                more: true,
186            };
187
188            query_session
189                .sender
190                .send(message)
191                .await
192                .map_err(|e| anyhow!("Failed to send results: {}", e))?;
193
194            // Update metrics
195            self.metrics.results_sent.add(batch.len() as u64);
196
197            // Update query session
198            {
199                let mut queries = self.active_queries.write().await;
200                if let Some(session) = queries.get_mut(query_id) {
201                    session.results_sent += batch.len();
202                }
203            }
204        }
205
206        // Send completion message
207        let message = WebSocketMessage::QueryComplete {
208            id: query_id.to_string(),
209            total_results: bindings.len(),
210        };
211
212        query_session
213            .sender
214            .send(message)
215            .await
216            .map_err(|e| anyhow!("Failed to send completion: {}", e))?;
217
218        // Clean up query
219        self.complete_query(query_id).await;
220
221        Ok(())
222    }
223
224    /// Cancel a query
225    pub async fn cancel_query(&self, query_id: &str) -> Result<()> {
226        let mut queries = self.active_queries.write().await;
227        if let Some(session) = queries.get_mut(query_id) {
228            session.cancelled = true;
229
230            let message = WebSocketMessage::QueryCancelled {
231                id: query_id.to_string(),
232            };
233
234            let _ = session.sender.send(message).await;
235
236            self.metrics.queries_cancelled.inc();
237        }
238
239        queries.remove(query_id);
240        self.metrics.active_queries.sub(1.0);
241
242        Ok(())
243    }
244
245    /// Complete a query
246    async fn complete_query(&self, query_id: &str) {
247        let mut queries = self.active_queries.write().await;
248        if let Some(session) = queries.remove(query_id) {
249            let duration = session.start_time.elapsed();
250            self.metrics.query_duration.observe(duration);
251            self.metrics.active_queries.sub(1.0);
252            self.metrics.completed_queries.inc();
253        }
254    }
255
256    /// Send error to query
257    pub async fn send_error(&self, query_id: &str, error: String) -> Result<()> {
258        let queries = self.active_queries.read().await;
259        if let Some(session) = queries.get(query_id) {
260            let message = WebSocketMessage::QueryError {
261                id: query_id.to_string(),
262                error,
263            };
264
265            let _ = session.sender.send(message).await;
266            self.metrics.query_errors.inc();
267        }
268
269        Ok(())
270    }
271
272    /// Get session statistics
273    pub async fn statistics(&self) -> ConnectionStats {
274        let queries = self.active_queries.read().await;
275        let stats = self.metrics.query_duration.get_stats();
276
277        ConnectionStats {
278            session_id: self.id.clone(),
279            uptime: self.start_time.elapsed(),
280            active_queries: queries.len(),
281            total_queries: self.metrics.total_queries.get(),
282            completed_queries: self.metrics.completed_queries.get(),
283            cancelled_queries: self.metrics.queries_cancelled.get(),
284            failed_queries: self.metrics.query_errors.get(),
285            results_sent: self.metrics.results_sent.get(),
286            average_query_duration: stats.mean,
287        }
288    }
289
290    /// Check if session is healthy
291    pub fn is_healthy(&self) -> bool {
292        self.start_time.elapsed() < self.config.connection_timeout
293    }
294}
295
296/// Query execution session
297#[derive(Clone)]
298#[allow(dead_code)]
299struct QuerySession {
300    id: String,
301    sparql: String,
302    start_time: Instant,
303    results_sent: usize,
304    cancelled: bool,
305    sender: mpsc::Sender<WebSocketMessage>,
306}
307
308impl QuerySession {
309    fn is_cancelled(&self) -> bool {
310        self.cancelled
311    }
312}
313
314/// Session metrics
315struct SessionMetrics {
316    total_queries: Counter,
317    active_queries: Gauge,
318    completed_queries: Counter,
319    queries_cancelled: Counter,
320    query_errors: Counter,
321    results_sent: Counter,
322    query_duration: Timer,
323}
324
325impl SessionMetrics {
326    fn new() -> Self {
327        Self {
328            total_queries: Counter::new("websocket.total_queries".to_string()),
329            active_queries: Gauge::new("websocket.active_queries".to_string()),
330            completed_queries: Counter::new("websocket.completed_queries".to_string()),
331            queries_cancelled: Counter::new("websocket.queries_cancelled".to_string()),
332            query_errors: Counter::new("websocket.query_errors".to_string()),
333            results_sent: Counter::new("websocket.results_sent".to_string()),
334            query_duration: Timer::new("websocket.query_duration".to_string()),
335        }
336    }
337}
338
339/// Connection statistics
340#[derive(Debug, Clone, Serialize, Deserialize)]
341pub struct ConnectionStats {
342    pub session_id: String,
343    #[serde(serialize_with = "serialize_duration")]
344    pub uptime: Duration,
345    pub active_queries: usize,
346    pub total_queries: u64,
347    pub completed_queries: u64,
348    pub cancelled_queries: u64,
349    pub failed_queries: u64,
350    pub results_sent: u64,
351    pub average_query_duration: f64,
352}
353
354fn serialize_duration<S>(duration: &Duration, serializer: S) -> std::result::Result<S::Ok, S::Error>
355where
356    S: serde::Serializer,
357{
358    serializer.serialize_f64(duration.as_secs_f64())
359}
360
361/// WebSocket session manager
362pub struct WebSocketManager {
363    /// Configuration
364    config: WebSocketConfig,
365    /// Active sessions
366    sessions: Arc<RwLock<HashMap<String, Arc<WebSocketSession>>>>,
367    /// Global metrics
368    metrics: Arc<ManagerMetrics>,
369}
370
371impl WebSocketManager {
372    /// Create a new WebSocket manager
373    pub fn new(config: WebSocketConfig) -> Self {
374        Self {
375            config,
376            sessions: Arc::new(RwLock::new(HashMap::new())),
377            metrics: Arc::new(ManagerMetrics::new()),
378        }
379    }
380
381    /// Create a new session
382    pub async fn create_session(&self, session_id: String) -> Result<Arc<WebSocketSession>> {
383        let mut sessions = self.sessions.write().await;
384
385        if sessions.len() >= self.config.max_connections {
386            return Err(anyhow!("Maximum connections reached"));
387        }
388
389        let session = Arc::new(WebSocketSession::new(
390            session_id.clone(),
391            self.config.clone(),
392        ));
393        sessions.insert(session_id, session.clone());
394
395        self.metrics.active_sessions.add(1.0);
396        self.metrics.total_sessions.inc();
397
398        Ok(session)
399    }
400
401    /// Get a session
402    pub async fn get_session(&self, session_id: &str) -> Option<Arc<WebSocketSession>> {
403        let sessions = self.sessions.read().await;
404        sessions.get(session_id).cloned()
405    }
406
407    /// Remove a session
408    pub async fn remove_session(&self, session_id: &str) -> Result<()> {
409        let mut sessions = self.sessions.write().await;
410        if sessions.remove(session_id).is_some() {
411            self.metrics.active_sessions.sub(1.0);
412            self.metrics.closed_sessions.inc();
413        }
414        Ok(())
415    }
416
417    /// Get manager statistics
418    pub async fn statistics(&self) -> ManagerStats {
419        let sessions = self.sessions.read().await;
420
421        ManagerStats {
422            active_sessions: sessions.len(),
423            total_sessions: self.metrics.total_sessions.get(),
424            closed_sessions: self.metrics.closed_sessions.get(),
425            max_connections: self.config.max_connections,
426        }
427    }
428
429    /// Clean up inactive sessions
430    pub async fn cleanup_inactive_sessions(&self) -> usize {
431        let mut sessions = self.sessions.write().await;
432        let mut removed = 0;
433
434        sessions.retain(|_, session| {
435            if !session.is_healthy() {
436                removed += 1;
437                false
438            } else {
439                true
440            }
441        });
442
443        if removed > 0 {
444            self.metrics.active_sessions.sub(removed as f64);
445            self.metrics.closed_sessions.add(removed as u64);
446        }
447
448        removed
449    }
450}
451
452/// Manager metrics
453struct ManagerMetrics {
454    total_sessions: Counter,
455    active_sessions: Gauge,
456    closed_sessions: Counter,
457}
458
459impl ManagerMetrics {
460    fn new() -> Self {
461        Self {
462            total_sessions: Counter::new("websocket.manager.total_sessions".to_string()),
463            active_sessions: Gauge::new("websocket.manager.active_sessions".to_string()),
464            closed_sessions: Counter::new("websocket.manager.closed_sessions".to_string()),
465        }
466    }
467}
468
469/// Manager statistics
470#[derive(Debug, Clone, Serialize, Deserialize)]
471pub struct ManagerStats {
472    pub active_sessions: usize,
473    pub total_sessions: u64,
474    pub closed_sessions: u64,
475    pub max_connections: usize,
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a session with the default configuration.
    fn default_session() -> WebSocketSession {
        WebSocketSession::new("test-session".to_string(), WebSocketConfig::default())
    }

    /// A fresh session keeps its id and reports healthy.
    #[tokio::test]
    async fn test_websocket_session_creation() {
        let session = default_session();
        assert_eq!(session.id, "test-session");
        assert!(session.is_healthy());
    }

    /// Streaming an empty result set yields a single `QueryComplete`.
    #[tokio::test]
    async fn test_query_lifecycle() {
        let session = Arc::new(default_session());

        // Register the query and keep the receiving end.
        let mut rx = session
            .start_query("q1".to_string(), "SELECT * WHERE { ?s ?p ?o }".to_string())
            .await
            .unwrap();

        let variables: Vec<Variable> = ["s", "p", "o"]
            .iter()
            .map(|name| Variable::new(*name).unwrap())
            .collect();
        // No solutions are needed for this test.
        let bindings = Vec::new();

        // Produce results on a background task.
        let producer = Arc::clone(&session);
        tokio::spawn(async move {
            producer
                .stream_results("q1", variables, bindings)
                .await
                .unwrap();
        });

        // The only message expected is the completion notice.
        let received = rx.recv().await.unwrap();
        if let WebSocketMessage::QueryComplete { id, total_results } = received {
            assert_eq!(id, "q1");
            assert_eq!(total_results, 0);
        } else {
            panic!("Expected QueryComplete message");
        }
    }

    /// Cancelling a query removes it from the session's query table.
    #[tokio::test]
    async fn test_query_cancellation() {
        let session = default_session();

        let _rx = session
            .start_query("q1".to_string(), "SELECT * WHERE { ?s ?p ?o }".to_string())
            .await
            .unwrap();

        session.cancel_query("q1").await.unwrap();

        let still_registered = session.active_queries.read().await.contains_key("q1");
        assert!(!still_registered);
    }

    /// Create, look up, count and remove a session through the manager.
    #[tokio::test]
    async fn test_manager() {
        let manager = WebSocketManager::new(WebSocketConfig::default());

        let created = manager.create_session("s1".to_string()).await.unwrap();
        assert_eq!(created.id, "s1");

        let fetched = manager.get_session("s1").await.unwrap();
        assert_eq!(fetched.id, "s1");

        let after_create = manager.statistics().await;
        assert_eq!(after_create.active_sessions, 1);
        assert_eq!(after_create.total_sessions, 1);

        manager.remove_session("s1").await.unwrap();
        let after_remove = manager.statistics().await;
        assert_eq!(after_remove.active_sessions, 0);
    }
}