1use std::io::{BufRead, BufReader, Read};
7use std::time::Duration;
8
9use hickory_proto::rr::RecordType;
10use koi_common::mdns_protocol::{
11 AdminRegistration, DaemonStatus, RegisterPayload, RegistrationResult, RenewalResult,
12};
13use koi_common::net::resolve_localhost;
14use koi_common::types::{ServiceCheckKind, ServiceRecord};
15
/// Connect timeout applied to the regular request agent and to the streaming agent.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);

/// Read timeout applied to the regular request agent.
const READ_TIMEOUT: Duration = Duration::from_secs(10);

/// Connect and read timeout used by the `/healthz` probe.
const HEALTH_TIMEOUT: Duration = Duration::from_millis(200);
24
25#[derive(Debug, thiserror::Error)]
28pub enum ClientError {
29 #[error("Daemon not reachable: {0}")]
30 Unreachable(String),
31
32 #[error("remote daemon requires a token (pass --token or set KOI_TOKEN)")]
36 Unauthorized,
37
38 #[error("{error}: {message}")]
39 Api { error: String, message: String },
40
41 #[error("Request failed: {0}")]
42 Transport(String),
43
44 #[error("Invalid response: {0}")]
45 Decode(String),
46}
47
48impl ClientError {
49 pub fn is_unauthorized(&self) -> bool {
51 matches!(self, ClientError::Unauthorized)
52 }
53}
54
55pub type Result<T> = std::result::Result<T, ClientError>;
56
/// Name of the HTTP request header in which the client sends its token.
const DAT_HEADER: &str = "X-Koi-Token";
61
62pub struct KoiClient {
63 endpoint: String,
64 agent: ureq::Agent,
65 token: String,
67}
68
69impl KoiClient {
70 pub fn new(endpoint: &str) -> Self {
71 let clean = endpoint.trim_end_matches('/');
72 let resolved = resolve_localhost(clean);
73 let agent = ureq::AgentBuilder::new()
74 .timeout_connect(CONNECT_TIMEOUT)
75 .timeout_read(READ_TIMEOUT)
76 .build();
77 Self {
78 endpoint: resolved,
79 agent,
80 token: String::new(),
81 }
82 }
83
84 pub fn with_token(endpoint: &str, token: &str) -> Self {
86 let mut client = Self::new(endpoint);
87 client.token = token.to_string();
88 client
89 }
90
91 pub fn from_breadcrumb() -> Option<Self> {
95 let bc = koi_config::breadcrumb::read_breadcrumb()?;
96 Some(Self::with_token(&bc.endpoint, &bc.token))
97 }
98
99 fn auth_get(&self, url: &str) -> ureq::Request {
101 let req = self.agent.get(url);
102 if self.token.is_empty() {
103 req
104 } else {
105 req.set(DAT_HEADER, &self.token)
106 }
107 }
108
109 fn auth_post(&self, url: &str) -> ureq::Request {
111 let req = self.agent.post(url);
112 if self.token.is_empty() {
113 req
114 } else {
115 req.set(DAT_HEADER, &self.token)
116 }
117 }
118
119 fn auth_put(&self, url: &str) -> ureq::Request {
121 let req = self.agent.put(url);
122 if self.token.is_empty() {
123 req
124 } else {
125 req.set(DAT_HEADER, &self.token)
126 }
127 }
128
129 fn auth_delete(&self, url: &str) -> ureq::Request {
131 let req = self.agent.delete(url);
132 if self.token.is_empty() {
133 req
134 } else {
135 req.set(DAT_HEADER, &self.token)
136 }
137 }
138
139 pub fn health(&self) -> Result<()> {
143 let agent = ureq::AgentBuilder::new()
144 .timeout_connect(HEALTH_TIMEOUT)
145 .timeout_read(HEALTH_TIMEOUT)
146 .build();
147 let url = format!("{}/healthz", self.endpoint);
148 agent.get(&url).call().map_err(map_error)?;
149 Ok(())
150 }
151
152 pub fn register(&self, payload: &RegisterPayload) -> Result<RegistrationResult> {
155 let url = format!("{}/v1/mdns/announce", self.endpoint);
156 let json_val =
157 serde_json::to_value(payload).map_err(|e| ClientError::Decode(e.to_string()))?;
158 let resp = self
159 .auth_post(&url)
160 .send_json(json_val)
161 .map_err(map_error)?;
162 let json: serde_json::Value = resp
163 .into_json()
164 .map_err(|e| ClientError::Decode(e.to_string()))?;
165 extract(&json, "registered")
166 }
167
168 pub fn unregister(&self, id: &str) -> Result<()> {
169 let url = format!("{}/v1/mdns/unregister/{id}", self.endpoint);
170 self.auth_delete(&url).call().map_err(map_error)?;
171 Ok(())
172 }
173
174 pub fn heartbeat(&self, id: &str) -> Result<RenewalResult> {
175 let url = format!("{}/v1/mdns/heartbeat/{id}", self.endpoint);
176 let resp = self.auth_put(&url).send_bytes(&[]).map_err(map_error)?;
177 let json: serde_json::Value = resp
178 .into_json()
179 .map_err(|e| ClientError::Decode(e.to_string()))?;
180 extract(&json, "renewed")
181 }
182
183 pub fn resolve(&self, instance: &str) -> Result<ServiceRecord> {
184 let url = format!("{}/v1/mdns/resolve", self.endpoint);
185 let resp = self
186 .auth_get(&url)
187 .query("name", instance)
188 .call()
189 .map_err(map_error)?;
190 let json: serde_json::Value = resp
191 .into_json()
192 .map_err(|e| ClientError::Decode(e.to_string()))?;
193 extract(&json, "resolved")
194 }
195
196 pub fn browse_stream(&self, service_type: &str) -> Result<SseStream> {
198 let url = format!("{}/v1/mdns/discover", self.endpoint);
199 let mut req = self.stream_agent().get(&url);
200 if !self.token.is_empty() {
201 req = req.set(DAT_HEADER, &self.token);
202 }
203 let resp = req.query("type", service_type).call().map_err(map_error)?;
204 Ok(SseStream::new(Box::new(resp.into_reader())))
205 }
206
207 pub fn events_stream(&self, service_type: &str) -> Result<SseStream> {
209 let url = format!("{}/v1/mdns/subscribe", self.endpoint);
210 let mut req = self.stream_agent().get(&url);
211 if !self.token.is_empty() {
212 req = req.set(DAT_HEADER, &self.token);
213 }
214 let resp = req.query("type", service_type).call().map_err(map_error)?;
215 Ok(SseStream::new(Box::new(resp.into_reader())))
216 }
217
218 pub fn unified_status(&self) -> Result<serde_json::Value> {
222 let url = format!("{}/v1/status", self.endpoint);
223 let resp = self.auth_get(&url).call().map_err(map_error)?;
224 resp.into_json()
225 .map_err(|e| ClientError::Decode(e.to_string()))
226 }
227
228 pub fn dns_status(&self) -> Result<serde_json::Value> {
231 self.get_json("/v1/dns/status")
232 }
233
234 pub fn dns_lookup(&self, name: &str, record_type: RecordType) -> Result<serde_json::Value> {
235 let url = format!("{}/v1/dns/lookup", self.endpoint);
236 let resp = self
237 .auth_get(&url)
238 .query("name", name)
239 .query("type", record_type_str(record_type))
240 .call()
241 .map_err(map_error)?;
242 resp.into_json()
243 .map_err(|e| ClientError::Decode(e.to_string()))
244 }
245
246 pub fn dns_list(&self) -> Result<serde_json::Value> {
247 self.get_json("/v1/dns/list")
248 }
249
250 pub fn dns_add(&self, name: &str, ip: &str, ttl: Option<u32>) -> Result<serde_json::Value> {
251 let body = serde_json::json!({
252 "name": name,
253 "ip": ip,
254 "ttl": ttl,
255 });
256 self.post_json("/v1/dns/add", &body)
257 }
258
259 pub fn dns_remove(&self, name: &str) -> Result<serde_json::Value> {
260 let url = format!("{}/v1/dns/remove/{}", self.endpoint, name);
261 let resp = self.auth_delete(&url).call().map_err(map_error)?;
262 resp.into_json()
263 .map_err(|e| ClientError::Decode(e.to_string()))
264 }
265
266 pub fn dns_start(&self) -> Result<serde_json::Value> {
267 self.post_json("/v1/dns/serve", &serde_json::json!({}))
268 }
269
270 pub fn dns_stop(&self) -> Result<serde_json::Value> {
271 self.post_json("/v1/dns/stop", &serde_json::json!({}))
272 }
273
274 pub fn health_status(&self) -> Result<serde_json::Value> {
277 self.get_json("/v1/health/status")
278 }
279
280 pub fn health_add_check(
281 &self,
282 name: &str,
283 kind: ServiceCheckKind,
284 target: &str,
285 interval_secs: u64,
286 timeout_secs: u64,
287 ) -> Result<serde_json::Value> {
288 let body = serde_json::json!({
289 "name": name,
290 "kind": check_kind_str(kind),
291 "target": target,
292 "interval_secs": interval_secs,
293 "timeout_secs": timeout_secs,
294 });
295 self.post_json("/v1/health/add", &body)
296 }
297
298 pub fn health_remove_check(&self, name: &str) -> Result<serde_json::Value> {
299 let url = format!("{}/v1/health/remove/{}", self.endpoint, name);
300 let resp = self.auth_delete(&url).call().map_err(map_error)?;
301 resp.into_json()
302 .map_err(|e| ClientError::Decode(e.to_string()))
303 }
304
305 pub fn proxy_status(&self) -> Result<serde_json::Value> {
308 self.get_json("/v1/proxy/status")
309 }
310
311 pub fn proxy_list(&self) -> Result<serde_json::Value> {
312 self.get_json("/v1/proxy/list")
313 }
314
315 pub fn proxy_add(
316 &self,
317 name: &str,
318 listen_port: u16,
319 backend: &str,
320 allow_remote: bool,
321 ) -> Result<serde_json::Value> {
322 let body = serde_json::json!({
323 "name": name,
324 "listen_port": listen_port,
325 "backend": backend,
326 "allow_remote": allow_remote,
327 });
328 self.post_json("/v1/proxy/add", &body)
329 }
330
331 pub fn proxy_remove(&self, name: &str) -> Result<serde_json::Value> {
332 let url = format!("{}/v1/proxy/remove/{}", self.endpoint, name);
333 let resp = self.auth_delete(&url).call().map_err(map_error)?;
334 resp.into_json()
335 .map_err(|e| ClientError::Decode(e.to_string()))
336 }
337
338 pub fn udp_status(&self) -> Result<serde_json::Value> {
341 self.get_json("/v1/udp/status")
342 }
343
344 pub fn udp_bind(
345 &self,
346 port: u16,
347 addr: &str,
348 lease_secs: u64,
349 allow_remote: bool,
350 ) -> Result<serde_json::Value> {
351 let body = serde_json::json!({
352 "port": port,
353 "addr": addr,
354 "lease_secs": lease_secs,
355 "allow_remote": allow_remote,
356 });
357 self.post_json("/v1/udp/bind", &body)
358 }
359
360 pub fn udp_unbind(&self, id: &str) -> Result<serde_json::Value> {
361 let url = format!("{}/v1/udp/bind/{}", self.endpoint, id);
362 let resp = self.auth_delete(&url).call().map_err(map_error)?;
363 resp.into_json()
364 .map_err(|e| ClientError::Decode(e.to_string()))
365 }
366
367 pub fn udp_send(&self, id: &str, dest: &str, payload_b64: &str) -> Result<serde_json::Value> {
368 let body = serde_json::json!({
369 "dest": dest,
370 "payload": payload_b64,
371 });
372 let path = format!("/v1/udp/send/{id}");
373 self.post_json(&path, &body)
374 }
375
376 pub fn udp_heartbeat(&self, id: &str) -> Result<serde_json::Value> {
377 let path = format!("/v1/udp/heartbeat/{id}");
378 self.put_json(&path, &serde_json::json!({}))
379 }
380
381 pub fn post_json(&self, path: &str, body: &serde_json::Value) -> Result<serde_json::Value> {
385 let url = format!("{}{path}", self.endpoint);
386 let resp = self
387 .auth_post(&url)
388 .send_json(body.clone())
389 .map_err(map_error)?;
390 resp.into_json()
391 .map_err(|e| ClientError::Decode(e.to_string()))
392 }
393
394 pub fn get_json(&self, path: &str) -> Result<serde_json::Value> {
396 let url = format!("{}{path}", self.endpoint);
397 let resp = self.auth_get(&url).call().map_err(map_error)?;
398 resp.into_json()
399 .map_err(|e| ClientError::Decode(e.to_string()))
400 }
401
402 pub fn put_json(&self, path: &str, body: &serde_json::Value) -> Result<serde_json::Value> {
404 let url = format!("{}{path}", self.endpoint);
405 let resp = self
406 .auth_put(&url)
407 .send_json(body.clone())
408 .map_err(map_error)?;
409 resp.into_json()
410 .map_err(|e| ClientError::Decode(e.to_string()))
411 }
412
413 pub fn admin_status(&self) -> Result<DaemonStatus> {
416 let url = format!("{}/v1/mdns/admin/status", self.endpoint);
417 let resp = self.auth_get(&url).call().map_err(map_error)?;
418 resp.into_json()
419 .map_err(|e| ClientError::Decode(e.to_string()))
420 }
421
422 pub fn admin_registrations(&self) -> Result<Vec<AdminRegistration>> {
423 let url = format!("{}/v1/mdns/admin/ls", self.endpoint);
424 let resp = self.auth_get(&url).call().map_err(map_error)?;
425 resp.into_json()
426 .map_err(|e| ClientError::Decode(e.to_string()))
427 }
428
429 pub fn admin_inspect(&self, id: &str) -> Result<AdminRegistration> {
430 let url = format!("{}/v1/mdns/admin/inspect/{id}", self.endpoint);
431 let resp = self.auth_get(&url).call().map_err(map_error)?;
432 resp.into_json()
433 .map_err(|e| ClientError::Decode(e.to_string()))
434 }
435
436 pub fn admin_force_unregister(&self, id: &str) -> Result<()> {
437 let url = format!("{}/v1/mdns/admin/unregister/{id}", self.endpoint);
438 self.auth_delete(&url).call().map_err(map_error)?;
439 Ok(())
440 }
441
442 pub fn admin_drain(&self, id: &str) -> Result<()> {
443 let url = format!("{}/v1/mdns/admin/drain/{id}", self.endpoint);
444 self.auth_post(&url).call().map_err(map_error)?;
445 Ok(())
446 }
447
448 pub fn admin_revive(&self, id: &str) -> Result<()> {
449 let url = format!("{}/v1/mdns/admin/revive/{id}", self.endpoint);
450 self.auth_post(&url).call().map_err(map_error)?;
451 Ok(())
452 }
453
454 pub fn shutdown(&self) -> Result<()> {
458 let url = format!("{}/v1/admin/shutdown", self.endpoint);
459 self.auth_post(&url).call().map_err(map_error)?;
460 Ok(())
461 }
462
463 fn stream_agent(&self) -> ureq::Agent {
467 ureq::AgentBuilder::new()
468 .timeout_connect(CONNECT_TIMEOUT)
469 .build()
470 }
471}
472
473pub struct SseStream {
479 reader: BufReader<Box<dyn Read + Send>>,
480}
481
482impl SseStream {
483 fn new(reader: Box<dyn Read + Send>) -> Self {
484 Self {
485 reader: BufReader::new(reader),
486 }
487 }
488}
489
490impl Iterator for SseStream {
491 type Item = Result<serde_json::Value>;
492
493 fn next(&mut self) -> Option<Self::Item> {
494 loop {
495 let mut line = String::new();
496 match self.reader.read_line(&mut line) {
497 Ok(0) => return None,
498 Ok(_) => {
499 let trimmed = line.trim();
500 if let Some(data) = trimmed.strip_prefix("data:") {
501 let data = data.trim_start();
502 if data.is_empty() {
503 continue;
504 }
505 match serde_json::from_str(data) {
506 Ok(json) => return Some(Ok(json)),
507 Err(e) => return Some(Err(ClientError::Decode(e.to_string()))),
508 }
509 }
510 continue;
511 }
512 Err(e) => return Some(Err(ClientError::Transport(e.to_string()))),
513 }
514 }
515 }
516}
517
518fn map_error(e: ureq::Error) -> ClientError {
521 match e {
522 ureq::Error::Status(401, _resp) => ClientError::Unauthorized,
523 ureq::Error::Status(_status, resp) => {
524 let body = resp.into_string().unwrap_or_default();
525 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&body) {
526 let error = json
527 .get("error")
528 .and_then(|v| v.as_str())
529 .unwrap_or("unknown")
530 .to_string();
531 let message = json
532 .get("message")
533 .and_then(|v| v.as_str())
534 .unwrap_or(&body)
535 .to_string();
536 ClientError::Api { error, message }
537 } else {
538 ClientError::Api {
539 error: "http_error".into(),
540 message: body,
541 }
542 }
543 }
544 ureq::Error::Transport(t) => ClientError::Unreachable(t.to_string()),
545 }
546}
547
548fn record_type_str(record_type: RecordType) -> &'static str {
549 match record_type {
550 RecordType::A => "A",
551 RecordType::AAAA => "AAAA",
552 RecordType::ANY => "ANY",
553 _ => "A",
554 }
555}
556
557fn check_kind_str(kind: ServiceCheckKind) -> &'static str {
558 match kind {
559 ServiceCheckKind::Http => "http",
560 ServiceCheckKind::Tcp => "tcp",
561 }
562}
563
564fn extract<T: serde::de::DeserializeOwned>(json: &serde_json::Value, key: &str) -> Result<T> {
565 if let Some(err_val) = json.get("error") {
566 let error = err_val.as_str().unwrap_or("unknown").to_string();
567 let message = json
568 .get("message")
569 .and_then(|m| m.as_str())
570 .unwrap_or("Unknown error")
571 .to_string();
572 return Err(ClientError::Api { error, message });
573 }
574 json.get(key)
575 .ok_or_else(|| ClientError::Decode(format!("Missing '{key}' in response")))
576 .and_then(|v| {
577 serde_json::from_value(v.clone()).map_err(|e| ClientError::Decode(e.to_string()))
578 })
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`SseStream`] that reads from an in-memory copy of `input`.
    fn cursor_stream(input: &str) -> SseStream {
        let bytes = input.to_owned().into_bytes();
        SseStream::new(Box::new(std::io::Cursor::new(bytes)))
    }

    // ── ClientError ─────────────────────────────────────────────────────────────────────

    #[test]
    fn unauthorized_displays_actionable_hint() {
        let unauthorized = ClientError::Unauthorized;
        assert!(unauthorized.is_unauthorized());
        assert_eq!(
            unauthorized.to_string(),
            "remote daemon requires a token (pass --token or set KOI_TOKEN)"
        );
    }

    #[test]
    fn non_401_api_error_is_not_unauthorized() {
        let api_error = ClientError::Api {
            error: String::from("not_found"),
            message: String::from("nope"),
        };
        assert!(!api_error.is_unauthorized());
    }

    // ── KoiClient construction ──────────────────────────────────────────────────────────

    #[test]
    fn client_new_strips_trailing_slash() {
        let client = KoiClient::new("http://localhost:5641/");
        // `localhost` may be rewritten to a loopback address or left untouched.
        let accepted = [
            "http://127.0.0.1:5641",
            "http://[::1]:5641",
            "http://localhost:5641",
        ];
        assert!(
            accepted.contains(&client.endpoint.as_str()),
            "unexpected endpoint: {}",
            client.endpoint
        );
        assert!(!client.endpoint.ends_with('/'));
        assert!(client.token.is_empty());
    }

    #[test]
    fn client_with_token_sets_token() {
        let KoiClient {
            endpoint, token, ..
        } = KoiClient::with_token("http://10.0.0.1:5641", "my-secret-token");
        assert_eq!(endpoint, "http://10.0.0.1:5641");
        assert_eq!(token, "my-secret-token");
    }

    #[test]
    fn client_new_preserves_non_localhost() {
        let remote = "http://10.0.0.1:5641";
        assert_eq!(KoiClient::new(remote).endpoint, remote);
    }

    #[test]
    fn client_new_strips_multiple_trailing_slashes() {
        let client = KoiClient::new("http://localhost:5641///");
        assert!(!client.endpoint.ends_with('/'));
    }

    // ── SseStream ───────────────────────────────────────────────────────────────────────

    #[test]
    fn sse_stream_yields_parsed_json() {
        let event = cursor_stream("data: {\"foo\": 1}\n\n")
            .next()
            .expect("stream ended before the first event")
            .expect("event should decode");
        assert_eq!(event["foo"], 1);
    }

    #[test]
    fn sse_stream_skips_empty_lines() {
        assert!(cursor_stream("\n\n\n\n").next().is_none());
    }

    #[test]
    fn sse_stream_skips_non_data_lines() {
        assert!(cursor_stream("event: message\nretry: 1000\n\n")
            .next()
            .is_none());
    }

    #[test]
    fn sse_stream_handles_leading_space() {
        let event = cursor_stream("data: {\"hello\": \"world\"}\n")
            .next()
            .expect("stream ended before the first event")
            .expect("event should decode");
        assert_eq!(event["hello"], "world");
    }

    #[test]
    fn sse_stream_handles_no_space() {
        let event = cursor_stream("data:{\"hello\":\"world\"}\n")
            .next()
            .expect("stream ended before the first event")
            .expect("event should decode");
        assert_eq!(event["hello"], "world");
    }

    #[test]
    fn sse_stream_yields_multiple_events() {
        let events: Vec<serde_json::Value> = cursor_stream("data: {\"n\": 1}\n\ndata: {\"n\": 2}\n\n")
            .take(2)
            .map(|event| event.expect("event should decode"))
            .collect();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0]["n"], 1);
        assert_eq!(events[1]["n"], 2);
    }

    #[test]
    fn sse_stream_returns_none_on_eof() {
        let mut stream = cursor_stream("data: {\"n\": 1}\n");
        let _first = stream.next();
        assert!(stream.next().is_none());
    }

    #[test]
    fn sse_stream_decode_error_on_invalid_json() {
        let outcome = cursor_stream("data: {bad json}\n")
            .next()
            .expect("a decode error should be yielded");
        assert!(outcome.is_err());
    }

    #[test]
    fn sse_stream_transport_error_on_read_failure() {
        /// Reader whose every `read` call fails.
        struct BrokenReader;

        impl Read for BrokenReader {
            fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
                Err(std::io::Error::other("boom"))
            }
        }

        let mut stream = SseStream::new(Box::new(BrokenReader));
        let outcome = stream.next().expect("a transport error should be yielded");
        assert!(outcome.is_err());
    }
}