1use crate::algebra::Variable;
7use anyhow::{anyhow, Result};
8use scirs2_core::metrics::{Counter, Gauge, Timer};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::{mpsc, RwLock};
14
/// Tunable limits and behaviour switches shared by the WebSocket manager and
/// every session it creates.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Upper bound on a single message, presumably in bytes (not enforced in
    /// this file — confirm where it is applied).
    pub max_message_size: usize,
    /// Capacity of the bounded per-query channel created by `start_query`.
    pub buffer_size: usize,
    /// Interval between keep-alive pings (not used in this file).
    pub ping_interval: Duration,
    /// Maximum age of a session, measured from its creation, before
    /// `WebSocketSession::is_healthy` reports it as unhealthy.
    pub connection_timeout: Duration,
    /// Maximum number of sessions held by the manager; the same value also
    /// caps the number of concurrently active queries within one session.
    pub max_connections: usize,
    /// Whether message compression is requested (not consulted in this file).
    pub enable_compression: bool,
    /// Number of solutions packed into one `ResultBatch` by `stream_results`.
    pub batch_size: usize,
}
33
34impl Default for WebSocketConfig {
35 fn default() -> Self {
36 Self {
37 max_message_size: 16 * 1024 * 1024, buffer_size: 10000,
39 ping_interval: Duration::from_secs(30),
40 connection_timeout: Duration::from_secs(300),
41 max_connections: 1000,
42 enable_compression: true,
43 batch_size: 100,
44 }
45 }
46}
47
/// Messages exchanged over a WebSocket connection.
///
/// Serialized in serde's internally tagged form: each message carries a
/// `"type"` field holding the variant name alongside the variant's fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WebSocketMessage {
    /// Request to run a SPARQL query (presumably sent by the client; it is
    /// not constructed in this file).
    Query {
        /// Identifier chosen for the query; echoed in all related messages.
        id: String,
        /// SPARQL query text.
        sparql: String,
        /// Optional initial variable bindings, keyed by variable name.
        bindings: Option<HashMap<String, String>>,
    },
    /// One batch of query solutions, emitted by `stream_results`.
    ResultBatch {
        /// Identifier of the query the batch belongs to.
        id: String,
        /// Names of the projected variables.
        variables: Vec<String>,
        /// Solutions in this batch, each mapping a variable name to a
        /// string rendering of the bound term.
        solutions: Vec<HashMap<String, String>>,
        /// Flag indicating whether further batches follow (as the name
        /// suggests — confirm against the client protocol).
        more: bool,
    },
    /// Sent once all results of a query have been streamed.
    QueryComplete { id: String, total_results: usize },
    /// Reports a failure for the given query; emitted by `send_error`.
    QueryError { id: String, error: String },
    /// Request to cancel a running query (not constructed in this file).
    CancelQuery { id: String },
    /// Confirms that a query was cancelled; emitted by `cancel_query`.
    QueryCancelled { id: String },
    /// Keep-alive probe (not constructed in this file).
    Ping,
    /// Keep-alive reply (not constructed in this file).
    Pong,
    /// Carries a snapshot of connection statistics.
    Stats { stats: ConnectionStats },
}
80
/// State of a single WebSocket connection and the queries running on it.
pub struct WebSocketSession {
    /// Identifier of this session, supplied by the creator.
    id: String,
    /// Configuration the session was created with.
    config: WebSocketConfig,
    /// Queries currently registered on this session, keyed by query id.
    active_queries: Arc<RwLock<HashMap<String, QuerySession>>>,
    /// Counters, gauge and timer describing this session's activity.
    metrics: Arc<SessionMetrics>,
    /// Instant at which the session was created; basis for uptime and the
    /// health check.
    start_time: Instant,
}
94
95impl WebSocketSession {
96 pub fn new(id: String, config: WebSocketConfig) -> Self {
98 Self {
99 id,
100 config,
101 active_queries: Arc::new(RwLock::new(HashMap::new())),
102 metrics: Arc::new(SessionMetrics::new()),
103 start_time: Instant::now(),
104 }
105 }
106
107 pub async fn start_query(
109 &self,
110 query_id: String,
111 sparql: String,
112 ) -> Result<mpsc::Receiver<WebSocketMessage>> {
113 let (tx, rx) = mpsc::channel(self.config.buffer_size);
114
115 let session = QuerySession {
117 id: query_id.clone(),
118 sparql: sparql.clone(),
119 start_time: Instant::now(),
120 results_sent: 0,
121 cancelled: false,
122 sender: tx.clone(),
123 };
124
125 {
127 let mut queries = self.active_queries.write().await;
128 if queries.len() >= self.config.max_connections {
129 return Err(anyhow!("Maximum concurrent queries reached"));
130 }
131 queries.insert(query_id.clone(), session);
132 }
133
134 self.metrics.active_queries.add(1.0);
135 self.metrics.total_queries.inc();
136
137 Ok(rx)
138 }
139
140 pub async fn stream_results(
144 &self,
145 query_id: &str,
146 variables: Vec<Variable>,
147 bindings: Vec<crate::algebra::Binding>,
148 ) -> Result<()> {
149 let query_session = {
150 let queries = self.active_queries.read().await;
151 queries
152 .get(query_id)
153 .ok_or_else(|| anyhow!("Query not found: {}", query_id))?
154 .clone()
155 };
156
157 if query_session.cancelled {
158 return Ok(());
159 }
160
161 let var_names: Vec<String> = variables.iter().map(|v| v.to_string()).collect();
163
164 for batch in bindings.chunks(self.config.batch_size) {
166 if query_session.is_cancelled() {
167 break;
168 }
169
170 let solution_maps: Vec<HashMap<String, String>> = batch
172 .iter()
173 .map(|binding| {
174 binding
175 .iter()
176 .map(|(var, term)| (var.to_string(), format!("{:?}", term)))
177 .collect()
178 })
179 .collect();
180
181 let message = WebSocketMessage::ResultBatch {
182 id: query_id.to_string(),
183 variables: var_names.clone(),
184 solutions: solution_maps,
185 more: true,
186 };
187
188 query_session
189 .sender
190 .send(message)
191 .await
192 .map_err(|e| anyhow!("Failed to send results: {}", e))?;
193
194 self.metrics.results_sent.add(batch.len() as u64);
196
197 {
199 let mut queries = self.active_queries.write().await;
200 if let Some(session) = queries.get_mut(query_id) {
201 session.results_sent += batch.len();
202 }
203 }
204 }
205
206 let message = WebSocketMessage::QueryComplete {
208 id: query_id.to_string(),
209 total_results: bindings.len(),
210 };
211
212 query_session
213 .sender
214 .send(message)
215 .await
216 .map_err(|e| anyhow!("Failed to send completion: {}", e))?;
217
218 self.complete_query(query_id).await;
220
221 Ok(())
222 }
223
224 pub async fn cancel_query(&self, query_id: &str) -> Result<()> {
226 let mut queries = self.active_queries.write().await;
227 if let Some(session) = queries.get_mut(query_id) {
228 session.cancelled = true;
229
230 let message = WebSocketMessage::QueryCancelled {
231 id: query_id.to_string(),
232 };
233
234 let _ = session.sender.send(message).await;
235
236 self.metrics.queries_cancelled.inc();
237 }
238
239 queries.remove(query_id);
240 self.metrics.active_queries.sub(1.0);
241
242 Ok(())
243 }
244
245 async fn complete_query(&self, query_id: &str) {
247 let mut queries = self.active_queries.write().await;
248 if let Some(session) = queries.remove(query_id) {
249 let duration = session.start_time.elapsed();
250 self.metrics.query_duration.observe(duration);
251 self.metrics.active_queries.sub(1.0);
252 self.metrics.completed_queries.inc();
253 }
254 }
255
256 pub async fn send_error(&self, query_id: &str, error: String) -> Result<()> {
258 let queries = self.active_queries.read().await;
259 if let Some(session) = queries.get(query_id) {
260 let message = WebSocketMessage::QueryError {
261 id: query_id.to_string(),
262 error,
263 };
264
265 let _ = session.sender.send(message).await;
266 self.metrics.query_errors.inc();
267 }
268
269 Ok(())
270 }
271
272 pub async fn statistics(&self) -> ConnectionStats {
274 let queries = self.active_queries.read().await;
275 let stats = self.metrics.query_duration.get_stats();
276
277 ConnectionStats {
278 session_id: self.id.clone(),
279 uptime: self.start_time.elapsed(),
280 active_queries: queries.len(),
281 total_queries: self.metrics.total_queries.get(),
282 completed_queries: self.metrics.completed_queries.get(),
283 cancelled_queries: self.metrics.queries_cancelled.get(),
284 failed_queries: self.metrics.query_errors.get(),
285 results_sent: self.metrics.results_sent.get(),
286 average_query_duration: stats.mean,
287 }
288 }
289
290 pub fn is_healthy(&self) -> bool {
292 self.start_time.elapsed() < self.config.connection_timeout
293 }
294}
295
/// Bookkeeping for one query registered on a `WebSocketSession`.
#[derive(Clone)]
// `id` and `sparql` are stored but never read within this file.
#[allow(dead_code)]
struct QuerySession {
    /// Identifier of the query.
    id: String,
    /// SPARQL text the query was started with.
    sparql: String,
    /// Instant at which the query was registered; used to time it.
    start_time: Instant,
    /// Running count of solutions streamed for this query.
    results_sent: usize,
    /// Set when the query has been cancelled.
    cancelled: bool,
    /// Sending half of the query's message channel.
    sender: mpsc::Sender<WebSocketMessage>,
}
307
impl QuerySession {
    /// Returns the value of the `cancelled` flag of this (possibly cloned)
    /// query record.
    fn is_cancelled(&self) -> bool {
        self.cancelled
    }
}
313
/// Metrics recorded for a single session.
struct SessionMetrics {
    /// Number of queries started.
    total_queries: Counter,
    /// Number of queries currently registered.
    active_queries: Gauge,
    /// Number of queries that ran to completion.
    completed_queries: Counter,
    /// Number of queries that were cancelled.
    queries_cancelled: Counter,
    /// Number of query errors reported.
    query_errors: Counter,
    /// Number of solutions streamed to receivers.
    results_sent: Counter,
    /// Durations of completed queries.
    query_duration: Timer,
}
324
325impl SessionMetrics {
326 fn new() -> Self {
327 Self {
328 total_queries: Counter::new("websocket.total_queries".to_string()),
329 active_queries: Gauge::new("websocket.active_queries".to_string()),
330 completed_queries: Counter::new("websocket.completed_queries".to_string()),
331 queries_cancelled: Counter::new("websocket.queries_cancelled".to_string()),
332 query_errors: Counter::new("websocket.query_errors".to_string()),
333 results_sent: Counter::new("websocket.results_sent".to_string()),
334 query_duration: Timer::new("websocket.query_duration".to_string()),
335 }
336 }
337}
338
339#[derive(Debug, Clone, Serialize, Deserialize)]
341pub struct ConnectionStats {
342 pub session_id: String,
343 #[serde(serialize_with = "serialize_duration")]
344 pub uptime: Duration,
345 pub active_queries: usize,
346 pub total_queries: u64,
347 pub completed_queries: u64,
348 pub cancelled_queries: u64,
349 pub failed_queries: u64,
350 pub results_sent: u64,
351 pub average_query_duration: f64,
352}
353
354fn serialize_duration<S>(duration: &Duration, serializer: S) -> std::result::Result<S::Ok, S::Error>
355where
356 S: serde::Serializer,
357{
358 serializer.serialize_f64(duration.as_secs_f64())
359}
360
/// Registry of live WebSocket sessions.
pub struct WebSocketManager {
    /// Configuration cloned into every session the manager creates.
    config: WebSocketConfig,
    /// Sessions currently registered, keyed by session id.
    sessions: Arc<RwLock<HashMap<String, Arc<WebSocketSession>>>>,
    /// Counters and gauge describing session churn.
    metrics: Arc<ManagerMetrics>,
}
370
371impl WebSocketManager {
372 pub fn new(config: WebSocketConfig) -> Self {
374 Self {
375 config,
376 sessions: Arc::new(RwLock::new(HashMap::new())),
377 metrics: Arc::new(ManagerMetrics::new()),
378 }
379 }
380
381 pub async fn create_session(&self, session_id: String) -> Result<Arc<WebSocketSession>> {
383 let mut sessions = self.sessions.write().await;
384
385 if sessions.len() >= self.config.max_connections {
386 return Err(anyhow!("Maximum connections reached"));
387 }
388
389 let session = Arc::new(WebSocketSession::new(
390 session_id.clone(),
391 self.config.clone(),
392 ));
393 sessions.insert(session_id, session.clone());
394
395 self.metrics.active_sessions.add(1.0);
396 self.metrics.total_sessions.inc();
397
398 Ok(session)
399 }
400
401 pub async fn get_session(&self, session_id: &str) -> Option<Arc<WebSocketSession>> {
403 let sessions = self.sessions.read().await;
404 sessions.get(session_id).cloned()
405 }
406
407 pub async fn remove_session(&self, session_id: &str) -> Result<()> {
409 let mut sessions = self.sessions.write().await;
410 if sessions.remove(session_id).is_some() {
411 self.metrics.active_sessions.sub(1.0);
412 self.metrics.closed_sessions.inc();
413 }
414 Ok(())
415 }
416
417 pub async fn statistics(&self) -> ManagerStats {
419 let sessions = self.sessions.read().await;
420
421 ManagerStats {
422 active_sessions: sessions.len(),
423 total_sessions: self.metrics.total_sessions.get(),
424 closed_sessions: self.metrics.closed_sessions.get(),
425 max_connections: self.config.max_connections,
426 }
427 }
428
429 pub async fn cleanup_inactive_sessions(&self) -> usize {
431 let mut sessions = self.sessions.write().await;
432 let mut removed = 0;
433
434 sessions.retain(|_, session| {
435 if !session.is_healthy() {
436 removed += 1;
437 false
438 } else {
439 true
440 }
441 });
442
443 if removed > 0 {
444 self.metrics.active_sessions.sub(removed as f64);
445 self.metrics.closed_sessions.add(removed as u64);
446 }
447
448 removed
449 }
450}
451
/// Metrics recorded by the session manager.
struct ManagerMetrics {
    /// Number of sessions ever created.
    total_sessions: Counter,
    /// Number of sessions currently registered.
    active_sessions: Gauge,
    /// Number of sessions that were removed.
    closed_sessions: Counter,
}
458
459impl ManagerMetrics {
460 fn new() -> Self {
461 Self {
462 total_sessions: Counter::new("websocket.manager.total_sessions".to_string()),
463 active_sessions: Gauge::new("websocket.manager.active_sessions".to_string()),
464 closed_sessions: Counter::new("websocket.manager.closed_sessions".to_string()),
465 }
466 }
467}
468
/// Snapshot of the manager's state, as produced by
/// `WebSocketManager::statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ManagerStats {
    /// Number of sessions registered at snapshot time.
    pub active_sessions: usize,
    /// Number of sessions ever created.
    pub total_sessions: u64,
    /// Number of sessions that were removed.
    pub closed_sessions: u64,
    /// Configured upper bound on registered sessions.
    pub max_connections: usize,
}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created session keeps its id and is healthy.
    #[tokio::test]
    async fn test_websocket_session_creation() {
        let config = WebSocketConfig::default();
        let session = WebSocketSession::new("test-session".to_string(), config);
        assert_eq!(session.id, "test-session");
        assert!(session.is_healthy());
    }

    /// Streaming an empty result set yields `QueryComplete` and unregisters
    /// the query.
    #[tokio::test]
    async fn test_query_lifecycle() {
        let config = WebSocketConfig::default();
        let session = Arc::new(WebSocketSession::new("test-session".to_string(), config));

        let mut rx = session
            .start_query("q1".to_string(), "SELECT * WHERE { ?s ?p ?o }".to_string())
            .await
            .unwrap();

        let variables = vec![
            Variable::new("s").unwrap(),
            Variable::new("p").unwrap(),
            Variable::new("o").unwrap(),
        ];
        let bindings = vec![];

        let session_ref = Arc::clone(&session);
        let streamer = tokio::spawn(async move {
            session_ref.stream_results("q1", variables, bindings).await
        });

        let msg = rx.recv().await.unwrap();
        match msg {
            WebSocketMessage::QueryComplete { id, total_results } => {
                assert_eq!(id, "q1");
                assert_eq!(total_results, 0);
            }
            _ => panic!("Expected QueryComplete message"),
        }

        // Join the task so a panic or error inside `stream_results` fails the
        // test instead of being lost with a detached handle.
        streamer
            .await
            .expect("streaming task panicked")
            .expect("stream_results failed");

        let queries = session.active_queries.read().await;
        assert!(!queries.contains_key("q1"));
    }

    /// Cancelling a query notifies the receiver and unregisters the query;
    /// cancelling an unknown id is accepted.
    #[tokio::test]
    async fn test_query_cancellation() {
        let config = WebSocketConfig::default();
        let session = WebSocketSession::new("test-session".to_string(), config);

        let mut rx = session
            .start_query("q1".to_string(), "SELECT * WHERE { ?s ?p ?o }".to_string())
            .await
            .unwrap();

        session.cancel_query("q1").await.unwrap();

        match rx.recv().await {
            Some(WebSocketMessage::QueryCancelled { id }) => assert_eq!(id, "q1"),
            _ => panic!("Expected QueryCancelled message"),
        }

        session.cancel_query("does-not-exist").await.unwrap();

        let queries = session.active_queries.read().await;
        assert!(!queries.contains_key("q1"));
    }

    /// Sessions can be created, looked up and removed through the manager.
    #[tokio::test]
    async fn test_manager() {
        let config = WebSocketConfig::default();
        let manager = WebSocketManager::new(config);

        let session = manager.create_session("s1".to_string()).await.unwrap();
        assert_eq!(session.id, "s1");

        let retrieved = manager.get_session("s1").await.unwrap();
        assert_eq!(retrieved.id, "s1");

        let stats = manager.statistics().await;
        assert_eq!(stats.active_sessions, 1);
        assert_eq!(stats.total_sessions, 1);

        manager.remove_session("s1").await.unwrap();

        let stats = manager.statistics().await;
        assert_eq!(stats.active_sessions, 0);
        assert_eq!(stats.closed_sessions, 1);
        assert!(manager.get_session("s1").await.is_none());
    }
}