Skip to main content

arbor_server/
server.rs

//! WebSocket server implementation.
//!
//! Handles client connections and routes messages to handlers.
4
5use crate::handlers::{
6    handle_context, handle_discover, handle_impact, handle_info, handle_node_get, handle_search,
7    SharedGraph,
8};
9use crate::protocol::{
10    ContextParams, DiscoverParams, ImpactParams, NodeGetParams, Request, Response, SearchParams,
11};
12use arbor_graph::ArborGraph;
13use futures_util::{SinkExt, StreamExt};
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::RwLock;
18use tokio_tungstenite::{accept_async, tungstenite::Message};
19use tracing::{debug, error, info, warn};
20
/// Server configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerConfig {
    /// Address to bind to.
    pub addr: SocketAddr,
}

impl Default for ServerConfig {
    /// Binds to the IPv4 loopback address (`127.0.0.1`) on port 7432.
    fn default() -> Self {
        Self {
            // Built directly from its parts: no string parsing, no panic path.
            addr: SocketAddr::from(([127, 0, 0, 1], 7432)),
        }
    }
}
34
35/// The Arbor WebSocket server.
36pub struct ArborServer {
37    config: ServerConfig,
38    graph: SharedGraph,
39}
40
41impl ArborServer {
42    /// Creates a new server with the given graph.
43    pub fn new(graph: ArborGraph, config: ServerConfig) -> Self {
44        Self {
45            config,
46            graph: Arc::new(RwLock::new(graph)),
47        }
48    }
49
50    /// Creates a new server with an existing shared graph handle.
51    pub fn new_with_shared(graph: SharedGraph, config: ServerConfig) -> Self {
52        Self { config, graph }
53    }
54
55    /// Returns a handle to the shared graph for updates.
56    pub fn graph(&self) -> SharedGraph {
57        self.graph.clone()
58    }
59
60    /// Runs the server, accepting connections forever.
61    pub async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
62        let listener = TcpListener::bind(&self.config.addr).await?;
63        info!("Arbor server listening on {}", self.config.addr);
64
65        loop {
66            match listener.accept().await {
67                Ok((stream, addr)) => {
68                    debug!("New connection from {}", addr);
69                    let graph = self.graph.clone();
70                    tokio::spawn(async move {
71                        if let Err(e) = handle_connection(stream, addr, graph).await {
72                            error!("Connection error from {}: {}", addr, e);
73                        }
74                    });
75                }
76                Err(e) => {
77                    error!("Accept error: {}", e);
78                }
79            }
80        }
81    }
82}
83
84/// Handles a single WebSocket connection.
85async fn handle_connection(
86    stream: TcpStream,
87    addr: SocketAddr,
88    graph: SharedGraph,
89) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
90    let ws_stream = accept_async(stream).await?;
91    info!("WebSocket connection established with {}", addr);
92
93    let (mut write, mut read) = ws_stream.split();
94
95    while let Some(msg) = read.next().await {
96        let msg = match msg {
97            Ok(m) => m,
98            Err(e) => {
99                warn!("Message error from {}: {}", addr, e);
100                break;
101            }
102        };
103
104        if msg.is_close() {
105            debug!("Client {} disconnected", addr);
106            break;
107        }
108
109        if msg.is_ping() {
110            write.send(Message::Pong(msg.into_data())).await?;
111            continue;
112        }
113
114        if msg.is_text() {
115            let text = msg.to_text().unwrap_or("");
116            let response = process_message(text, graph.clone()).await;
117            let json = serde_json::to_string(&response)?;
118            write.send(Message::Text(json)).await?;
119        }
120    }
121
122    info!("Connection closed: {}", addr);
123    Ok(())
124}
125
126/// Processes a JSON-RPC message and returns a response.
127async fn process_message(text: &str, graph: SharedGraph) -> Response {
128    // Parse the request
129    let request: Request = match serde_json::from_str(text) {
130        Ok(r) => r,
131        Err(_) => return Response::parse_error(),
132    };
133
134    let id = request.id.clone();
135    let method = request.method.as_str();
136
137    debug!("Processing method: {}", method);
138
139    // Route to handler
140    match method {
141        "graph.info" => handle_info(graph, id).await,
142
143        "discover" => match serde_json::from_value::<DiscoverParams>(request.params) {
144            Ok(params) => handle_discover(graph, id, params).await,
145            Err(e) => Response::invalid_params(id, e.to_string()),
146        },
147
148        "impact" => match serde_json::from_value::<ImpactParams>(request.params) {
149            Ok(params) => handle_impact(graph, id, params).await,
150            Err(e) => Response::invalid_params(id, e.to_string()),
151        },
152
153        "context" => match serde_json::from_value::<ContextParams>(request.params) {
154            Ok(params) => handle_context(graph, id, params).await,
155            Err(e) => Response::invalid_params(id, e.to_string()),
156        },
157
158        "search" => match serde_json::from_value::<SearchParams>(request.params) {
159            Ok(params) => handle_search(graph, id, params).await,
160            Err(e) => Response::invalid_params(id, e.to_string()),
161        },
162
163        "node.get" => match serde_json::from_value::<NodeGetParams>(request.params) {
164            Ok(params) => handle_node_get(graph, id, params).await,
165            Err(e) => Response::invalid_params(id, e.to_string()),
166        },
167
168        _ => Response::method_not_found(id, method),
169    }
170}