1use crate::service::config::{ServiceConfig, ServiceKind};
4use anyhow::{Context, Result, bail};
5use std::{
6 collections::BTreeMap,
7 path::{Path, PathBuf},
8 sync::Arc,
9};
10use tokio::{
11 net::unix::{OwnedReadHalf, OwnedWriteHalf},
12 process::Child,
13 sync::Mutex,
14 time,
15};
16use wcore::{
17 ToolRegistry,
18 model::Tool,
19 protocol::{
20 PROTOCOL_VERSION,
21 codec::{read_message, write_message},
22 ext::{
23 Capability, ExtConfigure, ExtConfigured, ExtError, ExtHello, ExtReady,
24 ExtRegisterTools, ExtRequest, ExtResponse, ExtToolCall, ExtToolResult, ExtToolSchemas,
25 ToolsList, capability, ext_request, ext_response,
26 },
27 },
28};
29
30pub struct ServiceHandle {
32 pub name: String,
33 pub capabilities: Vec<Capability>,
34 writer: Mutex<OwnedWriteHalf>,
35 reader: Mutex<OwnedReadHalf>,
36 rpc_lock: Mutex<()>,
38}
39
40impl ServiceHandle {
41 pub async fn request(&self, req: &ExtRequest) -> Result<ExtResponse> {
43 let _guard = self.rpc_lock.lock().await;
44 let mut w = self.writer.lock().await;
45 write_message(&mut *w, req).await.context("ext write")?;
46 drop(w);
47 let mut r = self.reader.lock().await;
48 let resp: ExtResponse = read_message(&mut *r).await.context("ext read")?;
49 Ok(resp)
50 }
51
52 pub async fn send(&self, req: &ExtRequest) -> Result<()> {
54 let _guard = self.rpc_lock.lock().await;
55 let mut w = self.writer.lock().await;
56 write_message(&mut *w, req).await.context("ext write")?;
57 Ok(())
58 }
59}
60
61#[derive(Default)]
63pub struct ServiceRegistry {
64 pub tools: BTreeMap<String, Arc<ServiceHandle>>,
66 pub query: BTreeMap<String, Arc<ServiceHandle>>,
68 pub tool_schemas: Vec<Tool>,
70}
71
72impl ServiceRegistry {
73 pub async fn dispatch_tool(
76 &self,
77 name: &str,
78 args: &str,
79 agent: &str,
80 task_id: Option<u64>,
81 ) -> Option<String> {
82 let handle = self.tools.get(name)?;
83 let req = ExtRequest {
84 msg: Some(ext_request::Msg::ToolCall(ExtToolCall {
85 name: name.to_owned(),
86 args: args.to_owned(),
87 agent: agent.to_owned(),
88 task_id,
89 })),
90 };
91 Some(
92 match time::timeout(std::time::Duration::from_secs(30), handle.request(&req)).await {
93 Ok(Ok(resp)) => match resp.msg {
94 Some(ext_response::Msg::ToolResult(ExtToolResult { result })) => result,
95 Some(ext_response::Msg::Error(ExtError { message })) => {
96 format!("service error: {message}")
97 }
98 other => format!("unexpected response: {other:?}"),
99 },
100 Ok(Err(e)) => format!("service unavailable: {name} ({e})"),
101 Err(_) => format!("service timeout: {name}"),
102 },
103 )
104 }
105
106 pub async fn register_tools(&self, tools: &mut ToolRegistry) {
108 tools.insert_all(self.tool_schemas.clone());
109 }
110}
111
112struct ServiceEntry {
114 config: ServiceConfig,
115 child: Option<Child>,
116 socket_path: PathBuf,
117}
118
119pub struct ServiceManager {
121 entries: BTreeMap<String, ServiceEntry>,
122 services_dir: PathBuf,
123 daemon_socket: PathBuf,
125}
126
127const HANDSHAKE_TIMEOUT: time::Duration = time::Duration::from_secs(10);
128
129impl ServiceManager {
130 pub fn new(
135 configs: &BTreeMap<String, ServiceConfig>,
136 config_dir: &Path,
137 daemon_socket: PathBuf,
138 ) -> Self {
139 let services_dir = config_dir.join("services");
140 let entries = configs
141 .iter()
142 .filter(|(_, c)| c.enabled)
143 .map(|(name, config)| {
144 let socket_path = services_dir.join(format!("{name}.sock"));
145 (
146 name.clone(),
147 ServiceEntry {
148 config: config.clone(),
149 child: None,
150 socket_path,
151 },
152 )
153 })
154 .collect();
155 Self {
156 entries,
157 services_dir,
158 daemon_socket,
159 }
160 }
161
162 pub async fn spawn_all(&mut self) -> Result<()> {
168 std::fs::create_dir_all(&self.services_dir).context("create services dir")?;
169 let logs_dir = &*wcore::paths::LOGS_DIR;
170 std::fs::create_dir_all(logs_dir).context("create logs dir")?;
171
172 for (name, entry) in &mut self.entries {
173 if entry.socket_path.exists() {
175 let _ = std::fs::remove_file(&entry.socket_path);
176 }
177
178 let cargo_bin = std::env::var("HOME").ok().map(|h| {
181 PathBuf::from(h)
182 .join(".cargo/bin")
183 .join(&entry.config.krate)
184 });
185 let binary = match cargo_bin {
186 Some(ref p) if p.exists() => p.as_path(),
187 _ => Path::new(&entry.config.krate),
188 };
189 tracing::info!(
190 service = %name,
191 binary = %binary.display(),
192 kind = ?entry.config.kind,
193 "spawning service"
194 );
195 let mut cmd = tokio::process::Command::new(binary);
196 for (k, v) in &entry.config.env {
197 cmd.env(k, v);
198 }
199
200 if !entry.config.env.contains_key("RUST_LOG")
202 && let Ok(rust_log) = std::env::var("RUST_LOG")
203 {
204 cmd.env("RUST_LOG", rust_log);
205 }
206
207 let log_path = logs_dir.join(format!("{name}.log"));
209 let log_file = std::fs::File::create(&log_path)
210 .with_context(|| format!("create log file for '{name}'"))?;
211 cmd.stdout(log_file.try_clone()?);
212 cmd.stderr(log_file);
213
214 cmd.arg("serve");
215 match entry.config.kind {
216 ServiceKind::Extension => {
217 cmd.arg("--socket").arg(&entry.socket_path);
218 }
219 ServiceKind::Gateway => {
220 cmd.arg("--daemon").arg(&self.daemon_socket);
221 let config_json = serde_json::to_string(&entry.config.config)
222 .unwrap_or_else(|_| "{}".to_owned());
223 cmd.arg("--config").arg(config_json);
224 }
225 }
226
227 cmd.kill_on_drop(true);
228 let child = cmd.spawn().with_context(|| {
229 format!("spawn service '{name}' (binary: {})", binary.display())
230 })?;
231 tracing::info!(service = %name, pid = child.id(), log = %log_path.display(), "spawned service");
232 entry.child = Some(child);
233 }
234
235 Ok(())
236 }
237
238 pub async fn handshake_all(&self) -> ServiceRegistry {
241 let mut registry = ServiceRegistry::default();
242
243 for (name, entry) in &self.entries {
244 if !matches!(entry.config.kind, ServiceKind::Extension) {
245 continue;
246 }
247
248 match self
249 .handshake_one(name, &entry.socket_path, &entry.config.config)
250 .await
251 {
252 Ok((handle, schemas)) => {
253 let handle = Arc::new(handle);
254 Self::register(&mut registry, &handle);
255 tracing::info!(
256 service = %name,
257 tools = schemas.len(),
258 "extension registered"
259 );
260 registry.tool_schemas.extend(schemas);
261 }
262 Err(e) => {
263 tracing::warn!(service = %name, error = %e, "extension handshake failed, skipping");
264 }
265 }
266 }
267
268 registry
269 }
270
271 async fn handshake_one(
274 &self,
275 name: &str,
276 socket_path: &Path,
277 config: &serde_json::Value,
278 ) -> Result<(ServiceHandle, Vec<Tool>)> {
279 let deadline = time::Instant::now() + HANDSHAKE_TIMEOUT;
281 loop {
282 if socket_path.exists() {
283 break;
284 }
285 if time::Instant::now() >= deadline {
286 bail!(
287 "socket not found after {}s: {}",
288 HANDSHAKE_TIMEOUT.as_secs(),
289 socket_path.display()
290 );
291 }
292 time::sleep(time::Duration::from_millis(50)).await;
293 }
294
295 let stream = time::timeout(
296 HANDSHAKE_TIMEOUT,
297 tokio::net::UnixStream::connect(socket_path),
298 )
299 .await
300 .context("connect timeout")?
301 .context("connect")?;
302
303 let (read_half, write_half) = stream.into_split();
304 let writer = Mutex::new(write_half);
305 let reader = Mutex::new(read_half);
306
307 let hello = ExtRequest {
309 msg: Some(ext_request::Msg::Hello(ExtHello {
310 version: PROTOCOL_VERSION.to_owned(),
311 })),
312 };
313 {
314 let mut w = writer.lock().await;
315 write_message(&mut *w, &hello)
316 .await
317 .context("write Hello")?;
318 }
319 let ready: ExtResponse = {
320 let mut r = reader.lock().await;
321 time::timeout(HANDSHAKE_TIMEOUT, read_message(&mut *r))
322 .await
323 .context("Ready timeout")?
324 .context("read Ready")?
325 };
326 let (service, capabilities) = match ready.msg {
327 Some(ext_response::Msg::Ready(ExtReady {
328 service,
329 capabilities,
330 ..
331 })) => (service, capabilities),
332 Some(ext_response::Msg::Error(ExtError { message })) => {
333 bail!("service error: {message}")
334 }
335 other => bail!("unexpected response to Hello: {other:?}"),
336 };
337 tracing::debug!(service = %service, "handshake Hello/Ready complete");
338
339 let handle = ServiceHandle {
340 name: service,
341 capabilities,
342 writer,
343 reader,
344 rpc_lock: Mutex::new(()),
345 };
346
347 let config_json = serde_json::to_string(config).context("serialize service config")?;
349 let configure_req = ExtRequest {
350 msg: Some(ext_request::Msg::Configure(ExtConfigure {
351 config: config_json,
352 })),
353 };
354 let configure_resp = time::timeout(HANDSHAKE_TIMEOUT, handle.request(&configure_req))
355 .await
356 .context("Configure timeout")?
357 .context("Configure")?;
358 match configure_resp.msg {
359 Some(ext_response::Msg::Configured(ExtConfigured {})) => {}
360 Some(ext_response::Msg::Error(ExtError { message })) => {
361 bail!("Configure error: {message}")
362 }
363 other => bail!("unexpected response to Configure: {other:?}"),
364 }
365 tracing::debug!(service = %name, "handshake Configure/Configured complete");
366
367 let register_tools_req = ExtRequest {
369 msg: Some(ext_request::Msg::RegisterTools(ExtRegisterTools {})),
370 };
371 let resp = time::timeout(HANDSHAKE_TIMEOUT, handle.request(®ister_tools_req))
372 .await
373 .context("RegisterTools timeout")?
374 .context("RegisterTools")?;
375 let tool_defs = match resp.msg {
376 Some(ext_response::Msg::ToolSchemas(ExtToolSchemas { tools })) => tools,
377 Some(ext_response::Msg::Error(ExtError { message })) => {
378 bail!("RegisterTools error: {message}")
379 }
380 other => bail!("unexpected response to RegisterTools: {other:?}"),
381 };
382 tracing::debug!(service = %name, tools = tool_defs.len(), "handshake RegisterTools/ToolSchemas complete");
383
384 let tools: Vec<Tool> = tool_defs
386 .into_iter()
387 .map(|td| Tool {
388 name: td.name.to_string(),
389 description: td.description.to_string(),
390 parameters: serde_json::from_slice(&td.parameters).unwrap_or_else(|_| true.into()),
391 strict: td.strict,
392 })
393 .collect();
394
395 Ok((handle, tools))
396 }
397
398 fn register(registry: &mut ServiceRegistry, handle: &Arc<ServiceHandle>) {
400 for cap in &handle.capabilities {
401 match &cap.cap {
402 Some(capability::Cap::Tools(ToolsList { names })) => {
403 for tool_name in names {
404 registry.tools.insert(tool_name.clone(), Arc::clone(handle));
405 }
406 }
407 Some(capability::Cap::Query(_)) => {
408 registry
409 .query
410 .insert(handle.name.to_string(), Arc::clone(handle));
411 }
412 _ => {}
413 }
414 }
415 }
416
417 pub async fn shutdown_all(&mut self) {
420 for (name, entry) in &mut self.entries {
422 if let Some(ref mut child) = entry.child {
423 tracing::debug!(service = %name, pid = child.id(), "stopping service");
424 let _ = child.start_kill();
425 }
426 }
427
428 for (name, entry) in &mut self.entries {
430 if let Some(ref mut child) = entry.child {
431 match time::timeout(time::Duration::from_secs(5), child.wait()).await {
432 Ok(Ok(status)) => {
433 tracing::debug!(service = %name, %status, "service exited");
434 }
435 Ok(Err(e)) => {
436 tracing::warn!(service = %name, error = %e, "error waiting for service");
437 }
438 Err(_) => {
439 tracing::warn!(service = %name, "service did not exit in 5s, killing");
440 let _ = child.kill().await;
441 }
442 }
443 }
444 let _ = std::fs::remove_file(&entry.socket_path);
445 }
446 }
447}