use crate::{
auth::{AuthUser, SessionManager},
error::{FusekiError, FusekiResult},
metrics::MetricsService,
store::Store,
};
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::{IntoResponse, Response},
};
use dashmap::DashMap;
use futures::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::{Duration, Instant},
};
use tokio::{
sync::{broadcast, mpsc},
time::interval,
};
use tracing::{error, info, warn};
use uuid::Uuid;
/// Tracks live WebSocket connections and their SPARQL query subscriptions, re-evaluates
/// those queries, and pushes `QueryUpdate` messages when a result set changes.
///
/// All state sits behind `Arc`s (or a cloneable channel sender), so cloning the manager
/// yields another handle onto the same maps.
pub struct SubscriptionManager {
    /// Active subscriptions keyed by subscription id.
    subscriptions: Arc<DashMap<String, Subscription>>,
    /// Index from the hash of a query string to the ids of subscriptions using that query.
    query_subscriptions: Arc<DashMap<u64, HashSet<String>>>,
    /// Open connections keyed by connection id.
    connections: Arc<DashMap<String, ConnectionInfo>>,
    /// Re-publishes store change notifications to any broadcast listeners.
    change_broadcaster: broadcast::Sender<ChangeNotification>,
    /// Executes subscription queries.
    query_executor: Arc<QueryExecutor>,
    /// Counter sink for connection/subscription/notification metrics.
    metrics: Arc<MetricsService>,
    /// Limits and timing settings.
    config: Arc<WebSocketConfig>,
    /// Validator for `Auth` messages; when `None`, every `Auth` message is accepted.
    session_manager: Option<Arc<SessionManager>>,
}
/// Limits and timing settings for the WebSocket subscription endpoint.
#[derive(Clone, Debug)]
pub struct WebSocketConfig {
    /// Maximum number of subscriptions a single connection may hold.
    pub max_subscriptions_per_connection: usize,
    /// Maximum number of subscriptions across all connections.
    pub max_total_subscriptions: usize,
    /// Period of the background loop that re-evaluates every subscription.
    pub evaluation_interval: Duration,
    /// Inactivity period after which the cleanup loop drops a connection.
    pub connection_timeout: Duration,
    /// Size limit, in bytes, for client messages.
    pub max_message_size: usize,
    /// When true, binary frames are accepted as gzip-compressed JSON; otherwise they are rejected.
    pub enable_compression: bool,
    /// Heartbeat period.
    pub heartbeat_interval: Duration,
}

impl Default for WebSocketConfig {
    /// 100 subscriptions per connection, 10 000 total, 1 s evaluation period,
    /// 5 min inactivity timeout, 10 MiB messages, compression on, 30 s heartbeat.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        Self {
            max_subscriptions_per_connection: 100,
            max_total_subscriptions: 10_000,
            evaluation_interval: Duration::from_secs(1),
            connection_timeout: Duration::from_secs(5 * 60),
            max_message_size: 10 * MIB,
            enable_compression: true,
            heartbeat_interval: Duration::from_secs(30),
        }
    }
}
/// One client subscription: a query re-run periodically, whose changed results are pushed
/// to the owning connection.
#[derive(Clone, Debug)]
pub struct Subscription {
    /// Subscription id (a UUID string) handed back to the client.
    pub id: String,
    /// Id of the connection that created and owns this subscription.
    pub connection_id: String,
    /// SPARQL query text.
    pub query: String,
    /// Query parameters supplied by the client.
    pub parameters: QueryParameters,
    /// Optional notification filter supplied by the client.
    pub filter: Option<NotificationFilter>,
    /// Hash of the result set most recently sent to the client, if any.
    pub last_result_hash: Option<u64>,
    /// Result set most recently sent to the client, if any.
    pub last_result: Option<QueryResult>,
    /// Creation time; used as the denominator of the rate-limit calculation.
    pub created_at: Instant,
    /// Initialised at creation and refreshed whenever a notification is sent.
    pub last_evaluated: Instant,
    /// Number of notifications sent so far.
    pub notification_count: u64,
}
/// Client-supplied parameters accompanying a subscription query.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct QueryParameters {
    /// Default graph IRIs.
    pub default_graph_uri: Vec<String>,
    /// Named graph IRIs.
    pub named_graph_uri: Vec<String>,
    /// Optional timeout in milliseconds.
    pub timeout_ms: Option<u64>,
    /// Requested result format.
    pub format: String,
}
/// Client-supplied rules that can suppress a notification even though the result changed.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NotificationFilter {
    /// Minimum percentage (0-100+) of changed rows required before notifying.
    pub min_change_threshold: Option<f64>,
    /// Variables the client is interested in.
    pub monitored_variables: Option<Vec<String>>,
    /// Minimum milliseconds between notifications.
    pub debounce_ms: Option<u64>,
    /// Maximum notifications per minute, averaged over the subscription's lifetime.
    pub rate_limit: Option<u32>,
}
/// Book-keeping for one open WebSocket connection.
#[derive(Debug)]
pub struct ConnectionInfo {
    /// Connection id (a UUID string).
    pub id: String,
    /// Authenticated user, if known.
    pub user: Option<AuthUser>,
    /// Ids of the subscriptions owned by this connection.
    pub subscription_ids: HashSet<String>,
    /// Time the connection was registered.
    pub connected_at: Instant,
    /// Time the most recent client frame was received; drives the inactivity timeout.
    pub last_activity: Instant,
    /// Queue feeding the task that writes frames to the socket.
    pub sender: mpsc::Sender<Message>,
    /// Lifecycle state.
    pub state: ConnectionState,
}
/// Lifecycle state of a WebSocket connection.
#[derive(Clone, Debug, PartialEq)]
pub enum ConnectionState {
    /// Initial state after the upgrade.
    Connected,
    /// Set after a successful `Auth` message.
    Authenticated,
    /// Set when the client sends a Close frame.
    Closing,
    /// Terminal state. NOTE(review): not assigned anywhere in this file.
    Closed,
}
/// Describes a change to the underlying store.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChangeNotification {
    /// Graphs affected by the change.
    pub graphs: Vec<String>,
    /// Kind of change.
    pub change_type: ChangeType,
    /// When the change happened.
    pub timestamp: std::time::SystemTime,
    /// Optional free-form details.
    pub details: Option<serde_json::Value>,
}
/// Category of a store change.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum ChangeType {
    Insert,
    Delete,
    Update,
    Clear,
    Load,
    Transaction,
}
/// JSON protocol spoken over the socket, internally tagged by a snake_case `"type"` field.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WsMessage {
    /// Client -> server: register a new subscription.
    Subscribe {
        query: String,
        parameters: QueryParameters,
        filter: Option<NotificationFilter>,
    },
    /// Client -> server: drop a subscription owned by this connection.
    Unsubscribe { subscription_id: String },
    /// Server -> client: a subscription's result set changed.
    QueryUpdate {
        subscription_id: String,
        result: QueryResult,
        changes: Option<ResultChanges>,
    },
    /// Server -> client: acknowledgement (sent after successful authentication).
    Ack {
        message_id: String,
        success: bool,
        error: Option<String>,
    },
    /// Server -> client: an error report.
    Error {
        code: String,
        message: String,
        details: Option<serde_json::Value>,
    },
    /// Client -> server: application-level ping; answered with `Pong`.
    Ping { timestamp: u64 },
    /// Server -> client: reply to `Ping`, echoing its timestamp.
    Pong { timestamp: u64 },
    /// Client -> server: authenticate with a JWT or session token.
    Auth { token: String },
    /// Server -> client: confirms a new subscription.
    Subscribed {
        subscription_id: String,
        query: String,
    },
    /// Server -> client: confirms a removed subscription.
    Unsubscribed { subscription_id: String },
}
/// Result of evaluating a subscription query: one map per solution row.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct QueryResult {
    /// Solution rows, each mapping variable name to its JSON value.
    pub bindings: Vec<HashMap<String, serde_json::Value>>,
    /// Timing, count and hash of the rows.
    pub metadata: ResultMetadata,
}
/// Metadata attached to a [`QueryResult`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ResultMetadata {
    /// Wall-clock execution time in milliseconds.
    pub execution_time_ms: u64,
    /// Number of solution rows.
    pub result_count: usize,
    /// Hash of the solution rows.
    pub result_hash: u64,
}
/// Row-level difference between two consecutive results of a subscription.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ResultChanges {
    /// Rows present only in the new result.
    pub added: Vec<HashMap<String, serde_json::Value>>,
    /// Rows present only in the old result.
    pub removed: Vec<HashMap<String, serde_json::Value>>,
    /// `(old, new)` row pairs considered to be the same row with changed values.
    pub modified: Vec<(
        HashMap<String, serde_json::Value>,
        HashMap<String, serde_json::Value>,
    )>,
}
/// Runs subscription queries.
pub struct QueryExecutor {
    /// Backing store. NOTE(review): not yet read by the query path in this file.
    store: Arc<Store>,
    /// Runtime handle captured at construction. NOTE(review): not yet read in this file.
    executor: tokio::runtime::Handle,
}
impl SubscriptionManager {
/// Creates a manager that accepts every `Auth` message (no session validation).
///
/// # Panics
/// Builds a [`QueryExecutor`], which captures the current Tokio runtime handle and
/// therefore panics outside a Tokio runtime.
pub fn new(store: Arc<Store>, metrics: Arc<MetricsService>, config: WebSocketConfig) -> Self {
    Self::assemble(store, metrics, config, None)
}

/// Creates a manager that validates `Auth` messages with `session_manager`.
///
/// # Panics
/// Same as [`SubscriptionManager::new`].
pub fn new_with_auth(
    store: Arc<Store>,
    metrics: Arc<MetricsService>,
    config: WebSocketConfig,
    session_manager: Arc<SessionManager>,
) -> Self {
    Self::assemble(store, metrics, config, Some(session_manager))
}

/// Shared constructor body for [`SubscriptionManager::new`] and [`SubscriptionManager::new_with_auth`].
fn assemble(
    store: Arc<Store>,
    metrics: Arc<MetricsService>,
    config: WebSocketConfig,
    session_manager: Option<Arc<SessionManager>>,
) -> Self {
    // Only the sending half is kept; the initial receiver is dropped when this returns.
    let (change_broadcaster, _initial_rx) = broadcast::channel(1000);
    Self {
        subscriptions: Arc::new(DashMap::new()),
        query_subscriptions: Arc::new(DashMap::new()),
        connections: Arc::new(DashMap::new()),
        change_broadcaster,
        query_executor: Arc::new(QueryExecutor::new(store)),
        metrics,
        config: Arc::new(config),
        session_manager,
    }
}
/// Spawns the background evaluation and cleanup loops on the Tokio runtime.
///
/// Both tasks are detached and run for the life of the runtime.
pub async fn start(&self) {
    info!("Starting WebSocket subscription manager");
    let evaluator = self.clone();
    let janitor = self.clone();
    tokio::spawn(async move { evaluator.evaluation_loop().await });
    tokio::spawn(async move { janitor.cleanup_loop().await });
}
/// Completes the WebSocket upgrade and serves the socket under a fresh connection id.
///
/// Errors from the connection handler are logged, not returned; the returned `Response`
/// is the upgrade response produced by axum.
pub async fn handle_websocket(&self, ws: WebSocketUpgrade, user: Option<AuthUser>) -> Response {
    let manager = self.clone();
    let connection_id = Uuid::new_v4().to_string();
    ws.on_upgrade(move |socket| async move {
        let outcome = manager.handle_connection(socket, connection_id, user).await;
        if let Err(e) = outcome {
            error!("WebSocket connection error: {}", e);
        }
    })
}
async fn handle_connection(
&self,
ws: WebSocket,
connection_id: String,
user: Option<AuthUser>,
) -> FusekiResult<()> {
info!("New WebSocket connection: {}", connection_id);
self.metrics
.increment_counter("websocket.connections.total", 1)
.await;
let (sender, receiver) = ws.split();
let (tx, rx) = mpsc::channel(100);
let conn_info = ConnectionInfo {
id: connection_id.clone(),
user,
subscription_ids: HashSet::new(),
connected_at: Instant::now(),
last_activity: Instant::now(),
sender: tx,
state: ConnectionState::Connected,
};
self.connections.insert(connection_id.clone(), conn_info);
let sender_task = tokio::spawn(Self::message_sender(sender, rx));
let receiver_task = tokio::spawn(
self.clone()
.message_receiver(receiver, connection_id.clone()),
);
let _ = tokio::try_join!(sender_task, receiver_task);
self.cleanup_connection(&connection_id).await;
info!("WebSocket connection closed: {}", connection_id);
self.metrics
.increment_counter("websocket.connections.closed", 1)
.await;
Ok(())
}
/// Forwards queued frames to the socket's write half.
///
/// Stops when the queue is closed (all senders dropped) or when a socket write fails.
async fn message_sender(
    mut sender: SplitSink<WebSocket, Message>,
    mut receiver: mpsc::Receiver<Message>,
) {
    loop {
        let Some(outgoing) = receiver.recv().await else {
            break;
        };
        let delivered = sender.send(outgoing).await.is_ok();
        if !delivered {
            break;
        }
    }
}
/// Reads frames from the socket until it ends or errors, dispatching each one.
///
/// A frame that fails to process is reported to the client as a `message_error` and does
/// not end the loop; a transport error does.
async fn message_receiver(self, mut receiver: SplitStream<WebSocket>, connection_id: String) {
    loop {
        let msg = match receiver.next().await {
            None => break,
            Some(Err(e)) => {
                error!("WebSocket receive error: {}", e);
                break;
            }
            Some(Ok(msg)) => msg,
        };
        if let Err(e) = self.handle_message(msg, &connection_id).await {
            error!("Error handling message: {}", e);
            self.send_error(&connection_id, "message_error", &e.to_string())
                .await;
        }
    }
}
/// Handles one WebSocket frame from `connection_id` and refreshes its `last_activity`.
///
/// How each frame type is handled:
/// - Text frames carry JSON `WsMessage`s.
/// - Binary frames carry gzip-compressed JSON, and are accepted only when
///   `config.enable_compression` is set.
/// - Pings are answered with Pongs.
/// - A Close frame marks the connection `Closing`.
///
/// # Errors
/// `bad_request` in any of these cases:
/// - a payload (compressed or decompressed) exceeds `config.max_message_size`;
/// - the payload is not valid JSON or not a known message;
/// - a binary frame arrives while compression is disabled.
///
/// Also propagates errors from the message handlers.
async fn handle_message(&self, msg: Message, connection_id: &str) -> FusekiResult<()> {
    if let Some(mut conn) = self.connections.get_mut(connection_id) {
        conn.last_activity = Instant::now();
    }
    let max_size = self.config.max_message_size;
    match msg {
        Message::Text(text) => {
            Self::ensure_within_limit(text.len(), max_size)?;
            let ws_msg = Self::parse_ws_message(text.as_bytes())?;
            self.handle_ws_message(ws_msg, connection_id).await?;
        }
        Message::Binary(data) => {
            if !self.config.enable_compression {
                return Err(FusekiError::bad_request("Binary messages not supported"));
            }
            Self::ensure_within_limit(data.len(), max_size)?;
            let decompressed = Self::decompress_message(&data)?;
            // A small gzip payload can inflate enormously; check the inflated size as well.
            Self::ensure_within_limit(decompressed.len(), max_size)?;
            let ws_msg = Self::parse_ws_message(&decompressed)?;
            self.handle_ws_message(ws_msg, connection_id).await?;
        }
        Message::Ping(data) => {
            self.send_message(connection_id, Message::Pong(data))
                .await?;
        }
        Message::Pong(_) => {}
        Message::Close(_) => {
            if let Some(mut conn) = self.connections.get_mut(connection_id) {
                conn.state = ConnectionState::Closing;
            }
        }
    }
    Ok(())
}

/// Decodes a JSON payload into a [`WsMessage`], distinguishing malformed JSON from
/// well-formed JSON that is not a known message.
fn parse_ws_message(json: &[u8]) -> FusekiResult<WsMessage> {
    let value: serde_json::Value = serde_json::from_slice(json)
        .map_err(|e| FusekiError::bad_request(format!("Invalid JSON: {e}")))?;
    serde_json::from_value(value)
        .map_err(|e| FusekiError::bad_request(format!("Invalid message: {e}")))
}

/// Rejects payloads larger than `max` bytes.
fn ensure_within_limit(len: usize, max: usize) -> FusekiResult<()> {
    if len > max {
        return Err(FusekiError::bad_request(format!(
            "Message of {len} bytes exceeds the maximum of {max} bytes"
        )));
    }
    Ok(())
}
async fn handle_ws_message(&self, msg: WsMessage, connection_id: &str) -> FusekiResult<()> {
match msg {
WsMessage::Subscribe {
query,
parameters,
filter,
} => {
self.handle_subscribe(connection_id, query, parameters, filter)
.await?;
}
WsMessage::Unsubscribe { subscription_id } => {
self.handle_unsubscribe(connection_id, &subscription_id)
.await?;
}
WsMessage::Ping { timestamp } => {
self.send_ws_message(connection_id, WsMessage::Pong { timestamp })
.await?;
}
WsMessage::Auth { token } => {
self.handle_auth(connection_id, &token).await?;
}
_ => {
return Err(FusekiError::bad_request("Unexpected message type"));
}
}
Ok(())
}
/// Registers a new subscription for `connection_id`, confirms it with `Subscribed`, and
/// immediately evaluates it once.
///
/// # Errors
/// Fails on limit or validation errors before anything is registered. Fails on send or
/// evaluation errors after registration; the subscription then stays registered.
async fn handle_subscribe(
    &self,
    connection_id: &str,
    query: String,
    parameters: QueryParameters,
    filter: Option<NotificationFilter>,
) -> FusekiResult<()> {
    self.check_subscription_limits(connection_id)?;
    Self::validate_subscription_query(&query)?;

    let subscription_id = Uuid::new_v4().to_string();
    let query_hash = Self::hash_query(&query);
    let now = Instant::now();
    self.subscriptions.insert(
        subscription_id.clone(),
        Subscription {
            id: subscription_id.clone(),
            connection_id: connection_id.to_string(),
            query: query.clone(),
            parameters,
            filter,
            last_result_hash: None,
            last_result: None,
            created_at: now,
            last_evaluated: now,
            notification_count: 0,
        },
    );
    self.query_subscriptions
        .entry(query_hash)
        .or_default()
        .insert(subscription_id.clone());
    if let Some(mut conn) = self.connections.get_mut(connection_id) {
        conn.subscription_ids.insert(subscription_id.clone());
    }

    let confirmation = WsMessage::Subscribed {
        subscription_id: subscription_id.clone(),
        query,
    };
    self.send_ws_message(connection_id, confirmation).await?;
    self.evaluate_subscription(&subscription_id).await?;

    self.metrics
        .increment_counter("websocket.subscriptions.created", 1)
        .await;
    info!(
        "Created subscription {} for connection {}",
        subscription_id, connection_id
    );
    Ok(())
}
/// Removes `subscription_id` if it belongs to `connection_id`, then confirms with `Unsubscribed`.
///
/// # Errors
/// `not_found` for an unknown id; `forbidden` when another connection owns it.
async fn handle_unsubscribe(
    &self,
    connection_id: &str,
    subscription_id: &str,
) -> FusekiResult<()> {
    let Some(owner) = self
        .subscriptions
        .get(subscription_id)
        .map(|sub| sub.connection_id.clone())
    else {
        return Err(FusekiError::not_found("Subscription not found"));
    };
    if owner != connection_id {
        return Err(FusekiError::forbidden("Not subscription owner"));
    }

    self.remove_subscription(subscription_id).await;
    let confirmation = WsMessage::Unsubscribed {
        subscription_id: subscription_id.to_string(),
    };
    self.send_ws_message(connection_id, confirmation).await
}
async fn handle_auth(&self, connection_id: &str, token: &str) -> FusekiResult<()> {
let authenticated = if let Some(ref session_manager) = self.session_manager {
match session_manager.validate_jwt_token(token) {
Ok(validation) => {
if let Some(mut conn) = self.connections.get_mut(connection_id) {
conn.state = ConnectionState::Authenticated;
conn.user = Some(AuthUser(validation.user));
}
true
}
Err(_) => {
match session_manager.validate_session(token).await {
Ok(crate::auth::types::AuthResult::Authenticated(user)) => {
if let Some(mut conn) = self.connections.get_mut(connection_id) {
conn.state = ConnectionState::Authenticated;
conn.user = Some(AuthUser(user));
}
true
}
_ => false,
}
}
}
} else {
if let Some(mut conn) = self.connections.get_mut(connection_id) {
conn.state = ConnectionState::Authenticated;
}
true
};
if authenticated {
self.send_ws_message(
connection_id,
WsMessage::Ack {
message_id: Uuid::new_v4().to_string(),
success: true,
error: None,
},
)
.await?;
Ok(())
} else {
self.send_ws_message(
connection_id,
WsMessage::Error {
code: "AUTH_FAILED".to_string(),
message: "Authentication failed: Invalid token".to_string(),
details: None,
},
)
.await?;
Err(FusekiError::authentication("Invalid authentication token"))
}
}
/// Rejects a new subscription when either the per-connection or the global limit is reached.
///
/// # Errors
/// `bad_request` for the per-connection limit; `service_unavailable` for the global one.
fn check_subscription_limits(&self, connection_id: &str) -> FusekiResult<()> {
    let held = self
        .connections
        .get(connection_id)
        .map(|conn| conn.subscription_ids.len());
    if held.is_some_and(|count| count >= self.config.max_subscriptions_per_connection) {
        return Err(FusekiError::bad_request(
            "Maximum subscriptions per connection exceeded",
        ));
    }
    if self.subscriptions.len() >= self.config.max_total_subscriptions {
        return Err(FusekiError::service_unavailable(
            "Maximum total subscriptions exceeded",
        ));
    }
    Ok(())
}
/// Performs a light syntactic check on a subscription query.
///
/// The query must:
/// - be non-empty;
/// - contain the `SELECT` or `CONSTRUCT` keyword;
/// - contain a `LIMIT` keyword.
///
/// Keywords are matched case-insensitively as whole words. Identifiers such as
/// `?selection`, `?limit` or `ex:select` therefore do not count as keywords.
///
/// # Errors
/// `bad_request` describing the first failed check.
pub fn validate_subscription_query(query: &str) -> FusekiResult<()> {
    /// True when `keyword` (lowercase ASCII) occurs in `query` as a standalone word.
    fn has_keyword(query: &str, keyword: &str) -> bool {
        // Characters that make a neighbouring letter sequence part of a larger name.
        let is_name_char = |c: char| c.is_alphanumeric() || matches!(c, '_' | ':' | '-');
        // ASCII lowercasing keeps byte offsets identical, so slicing below stays on char boundaries.
        let lower = query.to_ascii_lowercase();
        lower.match_indices(keyword).any(|(start, _)| {
            let before = lower[..start].chars().next_back();
            let after = lower[start + keyword.len()..].chars().next();
            // A leading `?`/`$` makes it a variable name, not a keyword.
            let starts_word =
                !before.is_some_and(|c| is_name_char(c) || c == '?' || c == '$');
            let ends_word = !after.is_some_and(is_name_char);
            starts_word && ends_word
        })
    }

    if query.trim().is_empty() {
        return Err(FusekiError::bad_request("Empty query"));
    }
    if !has_keyword(query, "select") && !has_keyword(query, "construct") {
        return Err(FusekiError::bad_request(
            "Only SELECT and CONSTRUCT queries supported for subscriptions",
        ));
    }
    if !has_keyword(query, "limit") {
        return Err(FusekiError::bad_request(
            "Subscription queries must include LIMIT clause",
        ));
    }
    Ok(())
}
async fn evaluation_loop(&self) {
let mut interval = interval(self.config.evaluation_interval);
loop {
interval.tick().await;
let subscription_ids: Vec<String> = self
.subscriptions
.iter()
.map(|entry| entry.key().clone())
.collect();
for id in subscription_ids {
if let Err(e) = self.evaluate_subscription(&id).await {
error!("Subscription evaluation error for {}: {}", id, e);
}
}
}
}
/// Runs one subscription's query and, if the result changed and passes the filter, pushes a
/// `QueryUpdate` and records what was sent.
///
/// Returns `Ok(())` without doing anything in these cases:
/// - the subscription no longer exists;
/// - its lifetime-average rate limit is exceeded;
/// - the result hash is unchanged;
/// - the filter rejects the change.
///
/// # Errors
/// Propagates query-execution, filter and send failures.
async fn evaluate_subscription(&self, subscription_id: &str) -> FusekiResult<()> {
    // Work on a snapshot so no map guard is held across the awaits below.
    let Some(subscription) = self
        .subscriptions
        .get(subscription_id)
        .map(|entry| entry.value().clone())
    else {
        return Ok(());
    };

    let over_rate_limit = subscription
        .filter
        .as_ref()
        .and_then(|filter| filter.rate_limit)
        .is_some_and(|rate_limit| {
            let notifications_per_minute = subscription.notification_count as f64
                / subscription.created_at.elapsed().as_secs_f64()
                * 60.0;
            notifications_per_minute > rate_limit as f64
        });
    if over_rate_limit {
        return Ok(());
    }

    let result = self
        .query_executor
        .execute_query(&subscription.query, &subscription.parameters)
        .await?;
    let result_hash = Self::hash_result(&result);
    if subscription.last_result_hash == Some(result_hash) {
        return Ok(());
    }

    if let Some(filter) = subscription.filter.as_ref() {
        let wanted = self
            .apply_notification_filter(&subscription, &result, filter)
            .await?;
        if !wanted {
            return Ok(());
        }
    }

    let changes = subscription
        .last_result
        .as_ref()
        .map(|previous| Self::calculate_result_changes(previous, &result));
    let update = WsMessage::QueryUpdate {
        subscription_id: subscription_id.to_string(),
        result: result.clone(),
        changes,
    };
    self.send_ws_message(&subscription.connection_id, update)
        .await?;

    if let Some(mut entry) = self.subscriptions.get_mut(subscription_id) {
        entry.last_result_hash = Some(result_hash);
        entry.last_result = Some(result);
        entry.last_evaluated = Instant::now();
        entry.notification_count += 1;
    }
    self.metrics
        .increment_counter("websocket.notifications.sent", 1)
        .await;
    Ok(())
}
async fn apply_notification_filter(
&self,
subscription: &Subscription,
result: &QueryResult,
filter: &NotificationFilter,
) -> FusekiResult<bool> {
if let Some(threshold) = filter.min_change_threshold {
let change_percentage = if let Some(ref old_result) = subscription.last_result {
let changes = Self::calculate_result_changes(old_result, result);
let total = old_result.bindings.len().max(result.bindings.len());
Self::calculate_change_percentage(&changes, total)
} else {
100.0 };
if change_percentage < threshold {
return Ok(false);
}
}
if let Some(debounce_ms) = filter.debounce_ms {
let time_since_last = subscription.last_evaluated.elapsed().as_millis() as u64;
if time_since_last < debounce_ms {
return Ok(false);
}
}
Ok(true)
}
/// Every 60 s, tears down connections idle for longer than `config.connection_timeout`. Runs forever.
async fn cleanup_loop(&self) {
    const SWEEP_PERIOD: Duration = Duration::from_secs(60);
    let mut ticker = interval(SWEEP_PERIOD);
    loop {
        ticker.tick().await;
        let timeout = self.config.connection_timeout;
        let now = Instant::now();
        let stale: Vec<String> = self
            .connections
            .iter()
            .filter_map(|entry| {
                (now.duration_since(entry.last_activity) > timeout).then(|| entry.key().clone())
            })
            .collect();
        for connection_id in stale {
            warn!("Cleaning up expired connection: {}", connection_id);
            self.cleanup_connection(&connection_id).await;
        }
    }
}
/// Unregisters a connection and removes every subscription listed on it.
///
/// Removing the entry drops its queue sender. A no-op if the connection is already gone.
async fn cleanup_connection(&self, connection_id: &str) {
    let Some((_, removed)) = self.connections.remove(connection_id) else {
        return;
    };
    for subscription_id in removed.subscription_ids.iter() {
        self.remove_subscription(subscription_id).await;
    }
}
/// Removes a subscription from the main map, the query index, and its connection's id set.
///
/// An index bucket left empty is dropped. A no-op for unknown ids.
async fn remove_subscription(&self, subscription_id: &str) {
    let Some((_, removed)) = self.subscriptions.remove(subscription_id) else {
        return;
    };

    let query_hash = Self::hash_query(&removed.query);
    let bucket_now_empty = match self.query_subscriptions.get_mut(&query_hash) {
        Some(mut ids) => {
            ids.remove(subscription_id);
            ids.is_empty()
        }
        None => false,
    };
    // The bucket guard is released above, before removing from the same map.
    if bucket_now_empty {
        self.query_subscriptions.remove(&query_hash);
    }

    if let Some(mut conn) = self.connections.get_mut(&removed.connection_id) {
        conn.subscription_ids.remove(subscription_id);
    }
}
/// Queues a raw frame for `connection_id`'s outbound writer.
///
/// Returns `Ok(())` without sending when the connection is unknown.
///
/// # Errors
/// `internal` when the outbound queue has been closed.
async fn send_message(&self, connection_id: &str, msg: Message) -> FusekiResult<()> {
    // Clone the channel handle and release the DashMap guard *before* awaiting. The queue is
    // bounded, so `send` can wait. A shard guard held meanwhile would block every writer to
    // that shard (activity updates, subscribe, cleanup) on a synchronous lock, inside the
    // async runtime.
    let Some(sender) = self
        .connections
        .get(connection_id)
        .map(|conn| conn.sender.clone())
    else {
        return Ok(());
    };
    sender
        .send(msg)
        .await
        .map_err(|_| FusekiError::internal("Failed to send message"))
}
async fn send_ws_message(&self, connection_id: &str, msg: WsMessage) -> FusekiResult<()> {
let json = serde_json::to_string(&msg)
.map_err(|e| FusekiError::internal(format!("Serialization error: {e}")))?;
self.send_message(connection_id, Message::Text(json.into()))
.await
}
async fn send_error(&self, connection_id: &str, code: &str, message: &str) {
let _ = self
.send_ws_message(
connection_id,
WsMessage::Error {
code: code.to_string(),
message: message.to_string(),
details: None,
},
)
.await;
}
/// Hashes a query string with the std `DefaultHasher`; deterministic within one build.
fn hash_query(query: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(query, &mut state);
    state.finish()
}
fn hash_result(result: &QueryResult) -> u64 {
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
for binding in &result.bindings {
for (k, v) in binding {
k.hash(&mut hasher);
v.to_string().hash(&mut hasher);
}
}
hasher.finish()
}
fn calculate_result_changes(
old_result: &QueryResult,
new_result: &QueryResult,
) -> ResultChanges {
let mut added = Vec::new();
let mut removed = Vec::new();
let mut modified = Vec::new();
let old_bindings_set: std::collections::HashSet<String> = old_result
.bindings
.iter()
.map(|b| serde_json::to_string(b).unwrap_or_default())
.collect();
let new_bindings_set: std::collections::HashSet<String> = new_result
.bindings
.iter()
.map(|b| serde_json::to_string(b).unwrap_or_default())
.collect();
for binding in &new_result.bindings {
let binding_str = serde_json::to_string(binding).unwrap_or_default();
if !old_bindings_set.contains(&binding_str) {
if let Some(old_binding) =
Self::find_matching_binding(binding, &old_result.bindings)
{
modified.push((old_binding.clone(), binding.clone()));
} else {
added.push(binding.clone());
}
}
}
for binding in &old_result.bindings {
let binding_str = serde_json::to_string(binding).unwrap_or_default();
if !new_bindings_set.contains(&binding_str) {
if !modified.iter().any(|(old, _)| {
old.keys().collect::<Vec<_>>() == binding.keys().collect::<Vec<_>>()
}) {
removed.push(binding.clone());
}
}
}
ResultChanges {
added,
removed,
modified,
}
}
fn find_matching_binding<'a>(
binding: &HashMap<String, serde_json::Value>,
bindings: &'a [HashMap<String, serde_json::Value>],
) -> Option<&'a HashMap<String, serde_json::Value>> {
let binding_keys: std::collections::HashSet<&String> = binding.keys().collect();
for candidate in bindings {
let candidate_keys: std::collections::HashSet<&String> = candidate.keys().collect();
if binding_keys == candidate_keys {
let values_differ = binding
.iter()
.any(|(k, v)| !candidate.get(k).is_some_and(|cv| cv == v));
if values_differ {
return Some(candidate);
}
}
}
None
}
/// Percentage of rows affected by `changes`, relative to `total_bindings`.
///
/// Returns 100.0 when `total_bindings` is zero. The value can exceed 100 because added,
/// removed and modified rows are all counted.
fn calculate_change_percentage(changes: &ResultChanges, total_bindings: usize) -> f64 {
    match total_bindings {
        0 => 100.0,
        total => {
            let changed: usize = [
                changes.added.len(),
                changes.removed.len(),
                changes.modified.len(),
            ]
            .iter()
            .sum();
            changed as f64 / total as f64 * 100.0
        }
    }
}
fn decompress_message(data: &[u8]) -> FusekiResult<Vec<u8>> {
use flate2::read::GzDecoder;
use std::io::Read;
let mut decoder = GzDecoder::new(data);
let mut decompressed = Vec::new();
decoder
.read_to_end(&mut decompressed)
.map_err(|e| FusekiError::bad_request(format!("Decompression error: {e}")))?;
Ok(decompressed)
}
/// Re-publishes a store change on the broadcast channel and re-evaluates every subscription.
///
/// Every subscription is re-evaluated regardless of `notification.graphs`. Evaluation errors
/// are logged per subscription.
pub async fn handle_store_change(&self, notification: ChangeNotification) {
    // An Err here only means nobody is currently subscribed to the broadcast channel.
    let _ = self.change_broadcaster.send(notification);

    let ids: Vec<String> = self
        .subscriptions
        .iter()
        .map(|entry| entry.key().clone())
        .collect();
    for id in &ids {
        if let Err(e) = self.evaluate_subscription(id).await {
            error!("Error evaluating subscription after change: {}", e);
        }
    }
}
}
/// Produces another handle onto the same shared state. Every field is an `Arc`, an
/// `Option<Arc>`, or a cloneable broadcast sender, so no map contents are copied.
impl Clone for SubscriptionManager {
    fn clone(&self) -> Self {
        // Destructuring makes the compiler flag any field added later but forgotten here.
        let Self {
            subscriptions,
            query_subscriptions,
            connections,
            change_broadcaster,
            query_executor,
            metrics,
            config,
            session_manager,
        } = self;
        Self {
            subscriptions: Arc::clone(subscriptions),
            query_subscriptions: Arc::clone(query_subscriptions),
            connections: Arc::clone(connections),
            change_broadcaster: change_broadcaster.clone(),
            query_executor: Arc::clone(query_executor),
            metrics: Arc::clone(metrics),
            config: Arc::clone(config),
            session_manager: session_manager.clone(),
        }
    }
}
impl QueryExecutor {
    /// Creates an executor over `store`.
    ///
    /// # Panics
    /// Captures `tokio::runtime::Handle::current()`, which panics outside a Tokio runtime.
    pub fn new(store: Arc<Store>) -> Self {
        Self {
            store,
            executor: tokio::runtime::Handle::current(),
        }
    }

    /// Executes `query` and wraps the rows with timing, count and content hash.
    ///
    /// # Errors
    /// Propagates failures from the underlying query execution.
    pub async fn execute_query(
        &self,
        query: &str,
        parameters: &QueryParameters,
    ) -> FusekiResult<QueryResult> {
        let start = Instant::now();
        let bindings = self.execute_sparql_query(query, parameters).await?;
        let execution_time_ms = start.elapsed().as_millis() as u64;
        let result_count = bindings.len();

        // Assemble the result first and hash it in place. The hash input is identical to
        // hashing a copy with `result_hash: 0`, but this avoids cloning every binding row.
        let mut result = QueryResult {
            bindings,
            metadata: ResultMetadata {
                execution_time_ms,
                result_count,
                result_hash: 0,
            },
        };
        result.metadata.result_hash = SubscriptionManager::hash_result(&result);
        Ok(result)
    }

    /// Placeholder query execution.
    ///
    /// Neither the store nor the parameters are consulted. For any query whose text
    /// contains "select" (case-insensitive), it returns three fixed sample rows binding
    /// `subject`, `predicate` and `object`; for anything else it returns no rows.
    /// NOTE(review): to be replaced with real evaluation against `self.store`.
    async fn execute_sparql_query(
        &self,
        query: &str,
        _parameters: &QueryParameters,
    ) -> FusekiResult<Vec<HashMap<String, serde_json::Value>>> {
        let mut bindings = Vec::new();
        if query.to_lowercase().contains("select") {
            for i in 0..3 {
                let mut binding = HashMap::new();
                binding.insert(
                    "subject".to_string(),
                    serde_json::json!(format!("http://example.org/resource{}", i)),
                );
                binding.insert(
                    "predicate".to_string(),
                    serde_json::json!("http://example.org/property"),
                );
                binding.insert(
                    "object".to_string(),
                    serde_json::json!(format!("Value {}", i)),
                );
                bindings.push(binding);
            }
        }
        Ok(bindings)
    }
}
pub async fn websocket_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<crate::server::AppState>>,
user: Option<AuthUser>,
) -> Response {
if let Some(ref subscription_manager) = state.subscription_manager {
subscription_manager.handle_websocket(ws, user).await
} else {
(
axum::http::StatusCode::SERVICE_UNAVAILABLE,
"WebSocket support not configured",
)
.into_response()
}
}
/// Query-string parameters for the WebSocket endpoint.
/// NOTE(review): not referenced anywhere in this file.
#[derive(Deserialize, Debug)]
pub struct WebSocketQuery {
    /// Optional authentication token.
    pub token: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_validation() {
        let accepted = [
            "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 100",
            "CONSTRUCT { ?s ?p ?o } WHERE { ?s ?p ?o } LIMIT 10",
        ];
        let rejected = [
            // empty
            "",
            // unsupported query form
            "ASK { ?s ?p ?o }",
            // missing LIMIT
            "SELECT ?s ?p ?o WHERE { ?s ?p ?o }",
        ];
        for query in &accepted {
            assert!(
                SubscriptionManager::validate_subscription_query(query).is_ok(),
                "expected acceptance: {query}"
            );
        }
        for query in &rejected {
            assert!(
                SubscriptionManager::validate_subscription_query(query).is_err(),
                "expected rejection: {query:?}"
            );
        }
    }

    #[test]
    fn test_message_serialization() {
        let original = WsMessage::Subscribe {
            query: "SELECT ?s WHERE { ?s ?p ?o } LIMIT 10".to_string(),
            parameters: QueryParameters {
                default_graph_uri: vec![],
                named_graph_uri: vec![],
                timeout_ms: Some(5000),
                format: "json".to_string(),
            },
            filter: Some(NotificationFilter {
                min_change_threshold: Some(5.0),
                monitored_variables: Some(vec!["s".to_string()]),
                debounce_ms: Some(1000),
                rate_limit: Some(60),
            }),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let as_value: serde_json::Value = serde_json::from_str(&encoded).unwrap();
        let decoded: WsMessage = serde_json::from_value(as_value).unwrap();

        let WsMessage::Subscribe { query, .. } = decoded else {
            panic!("Wrong message type");
        };
        assert!(query.contains("SELECT"));
    }

    #[test]
    fn test_query_hashing() {
        let hash = SubscriptionManager::hash_query;
        let first = "SELECT ?s WHERE { ?s ?p ?o }";
        let same_text = "SELECT ?s WHERE { ?s ?p ?o }";
        let different = "SELECT ?x WHERE { ?x ?y ?z }";
        assert_eq!(hash(first), hash(same_text));
        assert_ne!(hash(first), hash(different));
    }
}