#[cfg(feature = "portal")]
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Query, State,
},
response::IntoResponse,
};
#[cfg(feature = "portal")]
use futures::{SinkExt, StreamExt};
#[cfg(feature = "portal")]
use serde::Deserialize;
#[cfg(feature = "portal")]
use std::{collections::HashMap, sync::Arc, time::Duration};
#[cfg(feature = "portal")]
use tokio::sync::{broadcast, mpsc, RwLock};
#[cfg(feature = "portal")]
use tracing::{debug, error, info, warn};
#[cfg(feature = "portal")]
use uuid::Uuid;
#[cfg(feature = "portal")]
use crate::portal::{
auth::Claims,
auth_db::PortalState,
sync::{SyncConfig, SyncEvent, WsMessage},
sync_db::{
CreateSessionInput, DeviceSessionRepository, SyncQueueRepository, SyncStateRepository,
UpdateSyncStateInput,
},
};
#[cfg(feature = "portal")]
/// One live sync WebSocket connection: owning user, device identifier, and the
/// channel used to push outgoing [`WsMessage`]s to it.
///
/// NOTE(review): nothing in this file constructs this type — `WsState` stores bare
/// `mpsc::Sender<WsMessage>` values instead. Confirm it is used elsewhere before
/// relying on it.
#[derive(Debug)]
pub struct WsConnection {
    /// User that owns the connection.
    pub user_id: Uuid,
    /// Device identifier for the connection.
    pub device_id: String,
    /// Outgoing message queue for this connection.
    pub sender: mpsc::Sender<WsMessage>,
}
#[cfg(feature = "portal")]
/// Registry of live sync WebSocket connections, keyed by user and then by device.
///
/// Cloning is cheap: the connection map is shared behind an `Arc`, and the
/// broadcast sender is a handle onto the same channel.
#[derive(Clone)]
pub struct WsState {
    /// `user_id -> (device_id -> outgoing message channel)`.
    connections: Arc<RwLock<HashMap<Uuid, HashMap<String, mpsc::Sender<WsMessage>>>>>,
    /// Fan-out channel of `(user_id, event, excluded_device)` tuples; see
    /// [`WsState::publish`] and [`WsState::subscribe`].
    broadcast_tx: broadcast::Sender<(Uuid, SyncEvent, Option<String>)>,
    /// Sync tuning parameters (the socket handler reads `ping_interval_secs`).
    config: SyncConfig,
}

#[cfg(feature = "portal")]
impl WsState {
    /// Capacity of the internal publish/subscribe broadcast channel.
    const BROADCAST_CAPACITY: usize = 1024;

    /// Creates an empty registry configured with `config`.
    pub fn new(config: SyncConfig) -> Self {
        // The initial receiver is discarded; consumers obtain one via `subscribe`.
        let (broadcast_tx, _) = broadcast::channel(Self::BROADCAST_CAPACITY);
        Self {
            connections: Arc::new(RwLock::new(HashMap::new())),
            broadcast_tx,
            config,
        }
    }

    /// Registers `sender` as the outgoing channel for `device_id` of `user_id`,
    /// replacing (and dropping) any sender previously registered for that device.
    pub async fn register(
        &self,
        user_id: Uuid,
        device_id: String,
        sender: mpsc::Sender<WsMessage>,
    ) {
        let mut connections = self.connections.write().await;
        connections
            .entry(user_id)
            .or_default()
            .insert(device_id, sender);
    }

    /// Removes the sender for `device_id`; drops the user's entry once it has no
    /// devices left. Unknown users/devices are ignored.
    pub async fn unregister(&self, user_id: Uuid, device_id: &str) {
        let mut connections = self.connections.write().await;
        if let Some(devices) = connections.get_mut(&user_id) {
            devices.remove(device_id);
            if devices.is_empty() {
                connections.remove(&user_id);
            }
        }
    }

    /// Sends `message` to a single device.
    ///
    /// Returns `true` if the device is registered and its channel accepted the
    /// message, `false` otherwise. The registry lock is released before awaiting
    /// so a full per-device channel cannot block `register`/`unregister`.
    pub async fn send_to_device(&self, user_id: Uuid, device_id: &str, message: WsMessage) -> bool {
        let sender = {
            let connections = self.connections.read().await;
            connections
                .get(&user_id)
                .and_then(|devices| devices.get(device_id))
                .cloned()
        };
        match sender {
            Some(sender) => sender.send(message).await.is_ok(),
            None => false,
        }
    }

    /// Sends `event` to every connected device of `user_id` except `exclude_device`.
    ///
    /// Delivery is best effort: send failures (closed channels) are ignored. Senders
    /// are snapshotted under the read lock and the lock is released before any
    /// `send().await`, so a slow client cannot stall the registry.
    pub async fn broadcast(&self, user_id: Uuid, event: SyncEvent, exclude_device: Option<&str>) {
        let targets: Vec<mpsc::Sender<WsMessage>> = {
            let connections = self.connections.read().await;
            match connections.get(&user_id) {
                Some(devices) => devices
                    .iter()
                    .filter(|(device_id, _)| exclude_device != Some(device_id.as_str()))
                    .map(|(_, sender)| sender.clone())
                    .collect(),
                None => Vec::new(),
            }
        };
        for sender in targets {
            let _ = sender.send(WsMessage::Event(event.clone())).await;
        }
    }

    /// Number of devices currently registered for `user_id`.
    pub async fn connected_count(&self, user_id: Uuid) -> usize {
        let connections = self.connections.read().await;
        connections.get(&user_id).map_or(0, |devices| devices.len())
    }

    /// Returns a new receiver for events sent through [`WsState::publish`].
    pub fn subscribe(&self) -> broadcast::Receiver<(Uuid, SyncEvent, Option<String>)> {
        self.broadcast_tx.subscribe()
    }

    /// Publishes an event on the broadcast channel; the error returned when there
    /// are no subscribers is ignored.
    pub fn publish(&self, user_id: Uuid, event: SyncEvent, exclude_device: Option<String>) {
        let _ = self.broadcast_tx.send((user_id, event, exclude_device));
    }
}
#[cfg(feature = "portal")]
/// Query-string parameters accepted by `ws_handler` on the WebSocket upgrade request.
#[derive(Debug, Deserialize)]
pub struct WsUpgradeQuery {
    /// Access token, validated with `state.portal.auth.validate_token` before upgrading.
    // NOTE(review): tokens in query strings may end up in proxy/access logs, and the
    // derived `Debug` prints it verbatim — avoid logging this struct.
    pub token: String,
    /// Device identifier; keys the connection registry and the pending-event queue.
    pub device_id: String,
    /// Device name stored with the session; `"Unknown"` is used when absent.
    pub device_name: Option<String>,
    /// Platform stored with the session; `"web"` is used when absent.
    pub platform: Option<String>,
}
#[cfg(feature = "portal")]
/// Axum state shared by the sync WebSocket handler and the sync REST handlers.
#[derive(Clone)]
pub struct SyncWsState {
    /// Portal services; this module uses its `auth` (token validation) and `db` (pool).
    pub portal: PortalState,
    /// Registry of live sync WebSocket connections.
    pub ws: WsState,
}
#[cfg(feature = "portal")]
/// Axum handler that authenticates a sync WebSocket upgrade request.
///
/// The token is taken from the `token` query parameter. If validation fails the
/// request is answered with `401 Unauthorized`; otherwise the connection is
/// upgraded and handed to `handle_socket` together with the validated claims.
pub async fn ws_handler(
    ws: WebSocketUpgrade,
    State(state): State<SyncWsState>,
    Query(params): Query<WsUpgradeQuery>,
) -> impl IntoResponse {
    let claims = match state.portal.auth.validate_token(&params.token) {
        Ok(claims) => claims,
        Err(e) => {
            warn!("WebSocket auth failed: {}", e);
            return (axum::http::StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
        }
    };
    info!(
        "WebSocket upgrade for user {} device {}",
        claims.sub, params.device_id
    );
    ws.on_upgrade(move |socket| handle_socket(socket, state, claims, params))
}
#[cfg(feature = "portal")]
/// Drives one authenticated sync WebSocket connection until either side closes it.
///
/// Lifecycle:
/// 1. Upserts the device session in the database. On failure the socket is dropped.
/// 2. Spawns the writer task, which serializes queued [`WsMessage`]s as JSON text frames.
/// 3. Registers the device in [`WsState`], sends `Subscribed`, and replays up to 100
///    events still pending in the sync queue for this device.
/// 4. Spawns the reader task. It passes client text frames to `handle_message` and
///    enqueues an application-level `Ping` every `ping_interval_secs`.
/// 5. When either task finishes, aborts the other, unregisters the device, and marks
///    the session disconnected.
async fn handle_socket(
    socket: WebSocket,
    state: SyncWsState,
    claims: Claims,
    params: WsUpgradeQuery,
) {
    let user_id = match Uuid::parse_str(&claims.sub) {
        Ok(id) => id,
        Err(e) => {
            error!("Invalid user ID in claims: {}", e);
            return;
        }
    };
    let device_id = params.device_id.clone();
    let (mut ws_sender, mut ws_receiver) = socket.split();
    let (tx, mut rx) = mpsc::channel::<WsMessage>(32);

    let session_repo = DeviceSessionRepository::new(state.portal.db.pool());
    if let Err(e) = session_repo
        .upsert(
            user_id,
            &CreateSessionInput {
                device_id: device_id.clone(),
                device_name: params.device_name.unwrap_or_else(|| "Unknown".to_string()),
                platform: params.platform.unwrap_or_else(|| "web".to_string()),
                user_agent: None,
                ip_address: None,
            },
        )
        .await
    {
        error!("Failed to register session: {}", e);
        return;
    }

    // The writer must be running before anything is enqueued. `tx` is bounded (32)
    // and the replay below can enqueue up to 100 events. Without a consumer, those
    // `send().await`s would block forever once the buffer filled. Because the device
    // is already registered, broadcasts targeting it would then block as well.
    let mut send_task = tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            let json = match serde_json::to_string(&msg) {
                Ok(j) => j,
                Err(e) => {
                    error!("Failed to serialize message: {}", e);
                    continue;
                }
            };
            if ws_sender.send(Message::Text(json.into())).await.is_err() {
                break;
            }
        }
    });

    state
        .ws
        .register(user_id, device_id.clone(), tx.clone())
        .await;
    let _ = tx
        .send(WsMessage::Subscribed {
            session_id: Uuid::new_v4().to_string(),
        })
        .await;

    let queue_repo = SyncQueueRepository::new(state.portal.db.pool());
    match queue_repo.get_pending(&device_id, 100).await {
        Ok(pending) => {
            for item in pending {
                let event = match serde_json::from_value::<SyncEvent>(item.payload) {
                    Ok(event) => event,
                    Err(e) => {
                        // Left pending, so it is retried (and skipped) on each reconnect.
                        warn!(
                            "Skipping undecodable queued event for device {}: {}",
                            device_id, e
                        );
                        continue;
                    }
                };
                // Acknowledge only after the event reached the writer's queue. If the
                // writer has gone away, leave the rest pending for the next connection.
                if tx.send(WsMessage::Event(event)).await.is_err() {
                    break;
                }
                let _ = queue_repo.mark_delivered(item.id).await;
            }
        }
        Err(e) => warn!(
            "Failed to load pending events for device {}: {}",
            device_id, e
        ),
    }

    let state_clone = state.clone();
    let device_id_clone = device_id.clone();
    let tx_clone = tx.clone();
    let mut recv_task = tokio::spawn(async move {
        // `tokio::time::interval` panics on a zero period, so clamp to at least 1s.
        let ping_interval =
            Duration::from_secs(state_clone.ws.config.ping_interval_secs.max(1));
        let mut ping_timer = tokio::time::interval(ping_interval);
        loop {
            tokio::select! {
                msg = ws_receiver.next() => {
                    match msg {
                        Some(Ok(Message::Text(text))) => {
                            if let Err(e) = handle_message(
                                &state_clone,
                                user_id,
                                &device_id_clone,
                                &text,
                                &tx_clone,
                            ).await {
                                error!("Error handling message: {}", e);
                                let _ = tx_clone.send(WsMessage::Error {
                                    code: "HANDLER_ERROR".to_string(),
                                    message: e.to_string(),
                                }).await;
                            }
                        }
                        Some(Ok(Message::Ping(_))) => {
                            // Protocol-level pong replies are presumably sent by the
                            // underlying WebSocket implementation; only log here.
                            debug!("Received ping from {}", device_id_clone);
                        }
                        Some(Ok(Message::Pong(_))) => {
                            debug!("Received pong from {}", device_id_clone);
                        }
                        Some(Ok(Message::Close(_))) => {
                            info!("Client {} closed connection", device_id_clone);
                            break;
                        }
                        Some(Err(e)) => {
                            error!("WebSocket error: {}", e);
                            break;
                        }
                        None => break,
                        // Binary frames are ignored.
                        _ => {}
                    }
                }
                _ = ping_timer.tick() => {
                    if tx_clone.send(WsMessage::Ping).await.is_err() {
                        break;
                    }
                }
            }
        }
    });

    // Whichever side finishes first ends the connection. Stop the other task so it
    // does not linger after the socket is gone.
    tokio::select! {
        _ = &mut send_task => recv_task.abort(),
        _ = &mut recv_task => send_task.abort(),
    }

    // NOTE(review): if this device reconnected before this socket finished closing,
    // this removes the *new* connection's sender. Consider only unregistering when
    // the stored sender is still this one (`mpsc::Sender::same_channel`).
    state.ws.unregister(user_id, &device_id).await;
    let _ = session_repo.disconnect(user_id, &device_id).await;
    info!("WebSocket closed for user {} device {}", user_id, device_id);
}
#[cfg(feature = "portal")]
/// Parses one client text frame as a [`WsMessage`] and handles it.
///
/// * `Subscribe` — re-validates the supplied token and replies with an
///   `AUTH_FAILED` error if it is invalid. The connection stays open either way.
/// * `PushChanges` — for every change that carries a value:
///   - Writes it to the sync-state store (resource type `"settings"`), passing
///     `version` as the expected version.
///   - Broadcasts accepted writes to the user's other connected devices and hands
///     them to `queue_for_offline_devices`.
///   - Reports version conflicts back to this device as `ConflictDetected` events.
///
///   Finally replies `ChangesAccepted`.
/// * `Ping` — replies `Pong`. `Pong` is logged. Any other variant is ignored.
///
/// # Errors
/// Returns an error in three cases:
/// - `text` is not valid `WsMessage` JSON.
/// - A sync-state update fails for a reason other than a version conflict; the
///   remaining changes are then skipped.
/// - The reply channel `tx` is closed.
async fn handle_message(
    state: &SyncWsState,
    user_id: Uuid,
    device_id: &str,
    text: &str,
    tx: &mpsc::Sender<WsMessage>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let msg: WsMessage = serde_json::from_str(text)?;
    match msg {
        WsMessage::Subscribe {
            device_id: _,
            token,
        } => {
            if state.portal.auth.validate_token(&token).is_err() {
                tx.send(WsMessage::Error {
                    code: "AUTH_FAILED".to_string(),
                    message: "Invalid token".to_string(),
                })
                .await?;
            }
        }
        WsMessage::PushChanges { version, changes } => {
            let sync_repo = SyncStateRepository::new(state.portal.db.pool());
            for change in &changes {
                // NOTE(review): changes whose `value` is `None` are skipped entirely.
                // If `None` is meant to encode a deletion, it is never persisted.
                let Some(value) = &change.value else {
                    continue;
                };
                let input = UpdateSyncStateInput {
                    resource_type: "settings".to_string(),
                    resource_key: change.path.clone(),
                    value: value.clone(),
                    device_id: device_id.to_string(),
                };
                // NOTE(review): every change in the batch is checked against the same
                // client `version`; confirm versions are tracked per resource key.
                match sync_repo
                    .update(user_id, &input, Some(version as i64))
                    .await
                {
                    Ok(new_state) => {
                        let event = SyncEvent::SettingsChanged {
                            version: new_state.version as u64,
                            changes: vec![change.clone()],
                            device_id: device_id.to_string(),
                        };
                        state
                            .ws
                            .broadcast(user_id, event.clone(), Some(device_id))
                            .await;
                        queue_for_offline_devices(state, user_id, device_id, &event).await;
                    }
                    Err(crate::portal::sync_db::SyncStateError::Conflict {
                        current_value,
                        last_modified_at,
                        ..
                    }) => {
                        let conflict_event = SyncEvent::ConflictDetected {
                            path: change.path.clone(),
                            local_value: value.clone(),
                            remote_value: current_value,
                            local_timestamp: change.timestamp,
                            remote_timestamp: last_modified_at,
                        };
                        tx.send(WsMessage::Event(conflict_event)).await?;
                    }
                    Err(e) => {
                        error!("Sync state update failed: {}", e);
                        return Err(e.into());
                    }
                }
            }
            // NOTE(review): this is sent even when some changes conflicted, and
            // `version + 1` is not necessarily the version the store assigned.
            tx.send(WsMessage::ChangesAccepted {
                new_version: version + 1,
            })
            .await?;
        }
        WsMessage::Ping => {
            tx.send(WsMessage::Pong).await?;
        }
        WsMessage::Pong => {
            debug!("Pong received from {}", device_id);
        }
        // Remaining variants are not expected from clients and are ignored.
        _ => {}
    }
    Ok(())
}
#[cfg(feature = "portal")]
/// Presumably meant to persist `event` in the sync queue for the user's devices that
/// are not currently connected, so `handle_socket` can replay it via `get_pending`
/// on their next connection.
///
/// NOTE(review): not implemented. Nothing is written to the queue, so offline devices
/// miss events pushed while they were away. A real implementation needs two things
/// not used anywhere in this file:
/// - an enqueue API on `SyncQueueRepository`;
/// - a way to list all of the user's sessions, not only connected ones.
///
/// Until then this only logs that the event was *not* queued, rather than claiming
/// that it was.
async fn queue_for_offline_devices(
    _state: &SyncWsState,
    user_id: Uuid,
    exclude_device: &str,
    _event: &SyncEvent,
) {
    debug!(
        "Offline queueing not implemented; event for user {} from device {} was not queued",
        user_id, exclude_device
    );
}
#[cfg(feature = "portal")]
/// Handler reporting the caller's currently connected device sessions.
///
/// Responds `400` if the token subject is not a UUID and `500` if the session
/// lookup fails. Otherwise it returns a JSON object with `connected_devices` (the
/// count) and `devices`: one object per session, with its id, name, platform,
/// connection flag, last-connected time and settings version.
pub async fn get_sync_status(
    State(state): State<SyncWsState>,
    claims: crate::portal::middleware::AuthClaims,
) -> impl IntoResponse {
    use axum::{http::StatusCode, Json};

    let Ok(user_id) = Uuid::parse_str(&claims.0.sub) else {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({ "error": "Invalid user ID" })),
        )
            .into_response();
    };

    let repo = DeviceSessionRepository::new(state.portal.db.pool());
    match repo.get_connected(user_id).await {
        Err(err) => {
            error!("Failed to get sync status: {}", err);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": "Failed to get sync status" })),
            )
                .into_response()
        }
        Ok(sessions) => {
            let devices: Vec<serde_json::Value> = sessions
                .iter()
                .map(|session| {
                    serde_json::json!({
                        "device_id": session.device_id,
                        "device_name": session.device_name,
                        "platform": session.platform,
                        "connected": session.connected,
                        "last_connected_at": session.last_connected_at,
                        "settings_version": session.settings_version,
                    })
                })
                .collect();
            Json(serde_json::json!({
                "connected_devices": sessions.len(),
                "devices": devices,
            }))
            .into_response()
        }
    }
}
#[cfg(feature = "portal")]
/// Handler listing the caller's unresolved sync conflicts.
///
/// Responds `400` if the token subject is not a UUID and `500` if the lookup fails.
/// Otherwise it returns `{"conflicts": [...]}`.
pub async fn get_conflicts(
    State(state): State<SyncWsState>,
    claims: crate::portal::middleware::AuthClaims,
) -> impl IntoResponse {
    use crate::portal::sync_db::SyncConflictRepository;
    use axum::{http::StatusCode, Json};

    let Ok(user_id) = Uuid::parse_str(&claims.0.sub) else {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({ "error": "Invalid user ID" })),
        )
            .into_response();
    };

    let repo = SyncConflictRepository::new(state.portal.db.pool());
    match repo.get_unresolved(user_id).await {
        Err(err) => {
            error!("Failed to get conflicts: {}", err);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": "Failed to get conflicts" })),
            )
                .into_response()
        }
        Ok(unresolved) => Json(serde_json::json!({ "conflicts": unresolved })).into_response(),
    }
}
#[cfg(feature = "portal")]
/// JSON body of the conflict-resolution endpoint (`resolve_conflict`).
#[derive(Debug, Deserialize)]
pub struct ResolveConflictRequest {
    /// One of `"use_local"`, `"use_remote"`, `"use_newest"`, `"merge"`, `"manual"`.
    /// Any other value is rejected with `400`.
    pub strategy: String,
    /// Value to apply for the `"manual"` strategy; the conflict's local value is used
    /// when omitted. Ignored by the other strategies.
    pub custom_value: Option<serde_json::Value>,
}
#[cfg(feature = "portal")]
/// Resolves one of the caller's unresolved sync conflicts and applies the result.
///
/// `conflict_id` comes from the URL path. The body's `strategy` picks the value:
/// - `use_local` / `use_remote`: that side's value.
/// - `use_newest`: the side with the later timestamp (remote wins ties).
/// - `merge`: if both sides are JSON objects, a shallow merge with remote keys
///   overriding local ones; otherwise the remote value.
/// - `manual`: `custom_value`, falling back to the local value when omitted.
///
/// The conflict is searched only among the caller's own unresolved conflicts, so
/// another user's id yields `404`. The resolution is recorded first. The value is
/// then written to the sync-state store with no expected version (`None`).
///
/// Responses:
/// - `400` for a non-UUID subject or an unknown strategy.
/// - `404` when the conflict is not found.
/// - `500` when loading or recording the resolution fails.
/// - Otherwise `{"resolved": true, "value": ...}`.
pub async fn resolve_conflict(
    State(state): State<SyncWsState>,
    claims: crate::portal::middleware::AuthClaims,
    axum::extract::Path(conflict_id): axum::extract::Path<Uuid>,
    axum::Json(body): axum::Json<ResolveConflictRequest>,
) -> impl IntoResponse {
    use crate::portal::sync::ConflictResolution;

    let user_id = match Uuid::parse_str(&claims.0.sub) {
        Ok(id) => id,
        Err(_) => {
            return (
                axum::http::StatusCode::BAD_REQUEST,
                axum::Json(serde_json::json!({ "error": "Invalid user ID" })),
            )
                .into_response()
        }
    };
    let strategy = match body.strategy.as_str() {
        "use_local" => ConflictResolution::UseLocal,
        "use_remote" => ConflictResolution::UseRemote,
        "use_newest" => ConflictResolution::UseNewest,
        "merge" => ConflictResolution::Merge,
        "manual" => ConflictResolution::Manual,
        _ => {
            return (
                axum::http::StatusCode::BAD_REQUEST,
                axum::Json(serde_json::json!({ "error": "Invalid resolution strategy" })),
            )
                .into_response()
        }
    };

    let conflict_repo = crate::portal::sync_db::SyncConflictRepository::new(state.portal.db.pool());
    let conflicts = match conflict_repo.get_unresolved(user_id).await {
        Ok(conflicts) => conflicts,
        Err(e) => {
            error!("Failed to get conflict: {}", e);
            return (
                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
                axum::Json(serde_json::json!({ "error": "Failed to get conflict" })),
            )
                .into_response();
        }
    };
    let Some(conflict) = conflicts.into_iter().find(|c| c.id == conflict_id) else {
        return (
            axum::http::StatusCode::NOT_FOUND,
            axum::Json(serde_json::json!({ "error": "Conflict not found" })),
        )
            .into_response();
    };

    let resolved_value = match strategy {
        ConflictResolution::UseLocal => conflict.local_value.clone(),
        ConflictResolution::UseRemote => conflict.remote_value.clone(),
        ConflictResolution::UseNewest => {
            if conflict.local_timestamp > conflict.remote_timestamp {
                conflict.local_value.clone()
            } else {
                conflict.remote_value.clone()
            }
        }
        ConflictResolution::Manual => body
            .custom_value
            .clone()
            .unwrap_or_else(|| conflict.local_value.clone()),
        ConflictResolution::Merge => {
            if let (Some(local_obj), Some(remote_obj)) = (
                conflict.local_value.as_object(),
                conflict.remote_value.as_object(),
            ) {
                let mut merged = local_obj.clone();
                for (k, v) in remote_obj {
                    merged.insert(k.clone(), v.clone());
                }
                serde_json::Value::Object(merged)
            } else {
                conflict.remote_value.clone()
            }
        }
    };

    if let Err(e) = conflict_repo
        .resolve(conflict_id, strategy, &resolved_value, "user")
        .await
    {
        error!("Failed to resolve conflict: {}", e);
        return (
            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
            axum::Json(serde_json::json!({ "error": "Failed to resolve conflict" })),
        )
            .into_response();
    }

    // The resolution is already recorded, so a failure here cannot be rolled back
    // from this handler. Surface it in the logs instead of silently discarding it.
    let sync_repo = SyncStateRepository::new(state.portal.db.pool());
    if let Err(e) = sync_repo
        .update(
            user_id,
            &UpdateSyncStateInput {
                resource_type: conflict.resource_type,
                resource_key: conflict.resource_key,
                value: resolved_value.clone(),
                device_id: "conflict_resolution".to_string(),
            },
            None,
        )
        .await
    {
        error!(
            "Conflict {} resolved but applying the value to sync state failed: {}",
            conflict_id, e
        );
    }
    // NOTE(review): connected devices are not notified of the resolved value here.

    axum::Json(serde_json::json!({
        "resolved": true,
        "value": resolved_value,
    }))
    .into_response()
}
#[cfg(all(test, feature = "portal"))]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_ws_state_register_unregister() {
        let state = WsState::new(SyncConfig::default());
        let user_id = Uuid::new_v4();
        let device_id = "test-device".to_string();
        let (tx, _rx) = mpsc::channel(32);
        state.register(user_id, device_id.clone(), tx).await;
        assert_eq!(state.connected_count(user_id).await, 1);
        state.unregister(user_id, &device_id).await;
        assert_eq!(state.connected_count(user_id).await, 0);
    }

    #[tokio::test]
    async fn test_ws_state_multiple_devices() {
        let state = WsState::new(SyncConfig::default());
        let user_id = Uuid::new_v4();
        let (tx1, _rx1) = mpsc::channel(32);
        let (tx2, _rx2) = mpsc::channel(32);
        state.register(user_id, "device-1".to_string(), tx1).await;
        state.register(user_id, "device-2".to_string(), tx2).await;
        assert_eq!(state.connected_count(user_id).await, 2);
        // Removing one device must leave the other registered.
        state.unregister(user_id, "device-1").await;
        assert_eq!(state.connected_count(user_id).await, 1);
    }

    #[tokio::test]
    async fn test_send_to_device_routes_only_to_registered_device() {
        let state = WsState::new(SyncConfig::default());
        let user_id = Uuid::new_v4();
        let (tx, mut rx) = mpsc::channel(4);
        state.register(user_id, "device-1".to_string(), tx).await;

        assert!(state.send_to_device(user_id, "device-1", WsMessage::Ping).await);
        assert!(matches!(rx.recv().await, Some(WsMessage::Ping)));

        assert!(!state.send_to_device(user_id, "missing", WsMessage::Ping).await);
        assert!(!state.send_to_device(Uuid::new_v4(), "device-1", WsMessage::Ping).await);
    }

    #[tokio::test]
    async fn test_broadcast_skips_excluded_device() {
        let state = WsState::new(SyncConfig::default());
        let user_id = Uuid::new_v4();
        let (tx1, mut rx1) = mpsc::channel(4);
        let (tx2, mut rx2) = mpsc::channel(4);
        state.register(user_id, "device-1".to_string(), tx1).await;
        state.register(user_id, "device-2".to_string(), tx2).await;

        let event = SyncEvent::SettingsChanged {
            version: 1,
            changes: vec![],
            device_id: "device-2".to_string(),
        };
        state.broadcast(user_id, event, Some("device-2")).await;

        assert!(matches!(rx1.try_recv(), Ok(WsMessage::Event(_))));
        assert!(rx2.try_recv().is_err());
    }
}