1use crate::handlers::{
6 handle_context, handle_discover, handle_impact, handle_info, handle_node_get, handle_search,
7 SharedGraph,
8};
9use crate::protocol::{
10 ContextParams, DiscoverParams, ImpactParams, NodeGetParams, Request, Response, SearchParams,
11};
12use arbor_graph::ArborGraph;
13use futures_util::{SinkExt, StreamExt};
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::RwLock;
18use tokio_tungstenite::{accept_async, tungstenite::Message};
19use tracing::{debug, error, info, warn};
20
/// Settings for an [`ArborServer`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Socket address the server binds its TCP listener to.
    pub addr: SocketAddr,
}
26
27impl Default for ServerConfig {
28 fn default() -> Self {
29 Self {
30 addr: "127.0.0.1:7432".parse().unwrap(),
31 }
32 }
33}
34
35pub struct ArborServer {
37 config: ServerConfig,
38 graph: SharedGraph,
39}
40
41impl ArborServer {
42 pub fn new(graph: ArborGraph, config: ServerConfig) -> Self {
44 Self {
45 config,
46 graph: Arc::new(RwLock::new(graph)),
47 }
48 }
49
50 pub fn new_with_shared(graph: SharedGraph, config: ServerConfig) -> Self {
52 Self { config, graph }
53 }
54
55 pub fn graph(&self) -> SharedGraph {
57 self.graph.clone()
58 }
59
60 pub async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
62 let listener = TcpListener::bind(&self.config.addr).await?;
63 info!("Arbor server listening on {}", self.config.addr);
64
65 loop {
66 match listener.accept().await {
67 Ok((stream, addr)) => {
68 debug!("New connection from {}", addr);
69 let graph = self.graph.clone();
70 tokio::spawn(async move {
71 if let Err(e) = handle_connection(stream, addr, graph).await {
72 error!("Connection error from {}: {}", addr, e);
73 }
74 });
75 }
76 Err(e) => {
77 error!("Accept error: {}", e);
78 }
79 }
80 }
81 }
82}
83
84async fn handle_connection(
86 stream: TcpStream,
87 addr: SocketAddr,
88 graph: SharedGraph,
89) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
90 let ws_stream = accept_async(stream).await?;
91 info!("WebSocket connection established with {}", addr);
92
93 let (mut write, mut read) = ws_stream.split();
94
95 while let Some(msg) = read.next().await {
96 let msg = match msg {
97 Ok(m) => m,
98 Err(e) => {
99 warn!("Message error from {}: {}", addr, e);
100 break;
101 }
102 };
103
104 if msg.is_close() {
105 debug!("Client {} disconnected", addr);
106 break;
107 }
108
109 if msg.is_ping() {
110 write.send(Message::Pong(msg.into_data())).await?;
111 continue;
112 }
113
114 if msg.is_text() {
115 let text = msg.to_text().unwrap_or("");
116 let response = process_message(text, graph.clone()).await;
117 let json = serde_json::to_string(&response)?;
118 write.send(Message::Text(json)).await?;
119 }
120 }
121
122 info!("Connection closed: {}", addr);
123 Ok(())
124}
125
126async fn process_message(text: &str, graph: SharedGraph) -> Response {
128 let request: Request = match serde_json::from_str(text) {
130 Ok(r) => r,
131 Err(_) => return Response::parse_error(),
132 };
133
134 let id = request.id.clone();
135 let method = request.method.as_str();
136
137 debug!("Processing method: {}", method);
138
139 match method {
141 "graph.info" => handle_info(graph, id).await,
142
143 "discover" => match serde_json::from_value::<DiscoverParams>(request.params) {
144 Ok(params) => handle_discover(graph, id, params).await,
145 Err(e) => Response::invalid_params(id, e.to_string()),
146 },
147
148 "impact" => match serde_json::from_value::<ImpactParams>(request.params) {
149 Ok(params) => handle_impact(graph, id, params).await,
150 Err(e) => Response::invalid_params(id, e.to_string()),
151 },
152
153 "context" => match serde_json::from_value::<ContextParams>(request.params) {
154 Ok(params) => handle_context(graph, id, params).await,
155 Err(e) => Response::invalid_params(id, e.to_string()),
156 },
157
158 "search" => match serde_json::from_value::<SearchParams>(request.params) {
159 Ok(params) => handle_search(graph, id, params).await,
160 Err(e) => Response::invalid_params(id, e.to_string()),
161 },
162
163 "node.get" => match serde_json::from_value::<NodeGetParams>(request.params) {
164 Ok(params) => handle_node_get(graph, id, params).await,
165 Err(e) => Response::invalid_params(id, e.to_string()),
166 },
167
168 _ => Response::method_not_found(id, method),
169 }
170}