1use crate::{Error, Result};
29use crate::zap_capnp;
30use capnp_rpc::{rpc_twoparty_capnp, twoparty, RpcSystem};
31use futures::io::{BufReader, BufWriter};
32use serde_json::Value;
33use std::net::ToSocketAddrs;
34use tokio::net::TcpStream;
35use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
36use url::Url;
37
38fn text_to_string(reader: capnp::text::Reader<'_>) -> Result<String> {
40 reader.to_str()
41 .map(|s| s.to_string())
42 .map_err(|e| Error::Protocol(format!("invalid UTF-8: {}", e)))
43}
44
/// Name and version identifying a client implementation.
#[derive(Clone, Debug)]
pub struct ClientInfo {
    /// Client name.
    pub name: String,
    /// Client version string.
    pub version: String,
}
51
52#[derive(Debug, Clone)]
54pub struct ServerInfo {
55 pub name: String,
56 pub version: String,
57 pub capabilities: ServerCapabilities,
58}
59
/// Feature flags reported by the server during `Client::init`.
///
/// The derived `Default` leaves every flag `false`.
#[derive(Clone, Debug, Default)]
pub struct ServerCapabilities {
    /// Server reports tool support.
    pub tools: bool,
    /// Server reports resource support.
    pub resources: bool,
    /// Server reports prompt support.
    pub prompts: bool,
    /// Server reports logging support.
    pub logging: bool,
}
68
69#[derive(Debug, Clone)]
71pub struct Tool {
72 pub name: String,
73 pub description: String,
74 pub schema: Value,
75}
76
/// A resource entry returned by `Client::list_resources`.
#[derive(Clone, Debug)]
pub struct Resource {
    /// URI used to read or subscribe to the resource.
    pub uri: String,
    /// Display name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// MIME type reported by the server.
    pub mime_type: String,
}
85
86#[derive(Debug, Clone)]
88pub struct ResourceContent {
89 pub uri: String,
90 pub mime_type: String,
91 pub content: Content,
92}
93
/// Resource payload: either UTF-8 text or raw bytes.
#[derive(Clone, Debug)]
pub enum Content {
    /// Textual payload (validated as UTF-8 on receipt).
    Text(String),
    /// Binary payload.
    Blob(Vec<u8>),
}
100
101#[derive(Debug, Clone)]
103pub struct Prompt {
104 pub name: String,
105 pub description: String,
106 pub arguments: Vec<PromptArgument>,
107}
108
/// One argument accepted by a [`Prompt`].
#[derive(Clone, Debug)]
pub struct PromptArgument {
    /// Argument name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Whether the server marks this argument as required.
    pub required: bool,
}
116
117#[derive(Debug, Clone)]
119pub struct PromptMessage {
120 pub role: Role,
121 pub content: MessageContent,
122}
123
/// Author role of a [`PromptMessage`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Role {
    User,
    Assistant,
    System,
}
131
132#[derive(Debug, Clone)]
134pub enum MessageContent {
135 Text(String),
136 Image { data: Vec<u8>, mime_type: String },
137 Resource(ResourceContent),
138}
139
140pub struct ResourceStream {
142 stream_client: zap_capnp::resource_stream::Client,
143}
144
145impl ResourceStream {
146 fn new(client: zap_capnp::resource_stream::Client) -> Self {
147 Self { stream_client: client }
148 }
149
150 pub async fn next(&self) -> Result<Option<ResourceContent>> {
152 let request = self.stream_client.next_request();
153 let response = request.send().promise.await
154 .map_err(|e| Error::Protocol(format!("stream next failed: {}", e)))?;
155
156 let results = response.get()
157 .map_err(|e| Error::Protocol(format!("failed to get results: {}", e)))?;
158
159 if results.get_done() {
160 return Ok(None);
161 }
162
163 let content = results.get_content()
164 .map_err(|e| Error::Protocol(format!("failed to get content: {}", e)))?;
165
166 Ok(Some(convert_resource_content(content)?))
167 }
168
169 pub async fn cancel(&self) -> Result<()> {
171 let request = self.stream_client.cancel_request();
172 request.send().promise.await
173 .map_err(|e| Error::Protocol(format!("stream cancel failed: {}", e)))?;
174 Ok(())
175 }
176}
177
178pub struct Client {
183 zap_client: zap_capnp::zap::Client,
185 disconnector: capnp_rpc::Disconnector<rpc_twoparty_capnp::Side>,
187}
188
189impl Client {
190 pub async fn connect(url: &str) -> Result<Self> {
202 let parsed = Url::parse(url)?;
203
204 match parsed.scheme() {
205 "zap" | "zap+tcp" | "tcp" => {
206 let host = parsed.host_str().unwrap_or("localhost");
207 let port = parsed.port().unwrap_or(crate::DEFAULT_PORT);
208 let addr = format!("{}:{}", host, port);
209 Self::connect_tcp(&addr).await
210 }
211 scheme => Err(Error::Connection(format!(
212 "unsupported URL scheme '{}' - use zap://, zap+tcp://, or tcp://",
213 scheme
214 ))),
215 }
216 }
217
218 pub async fn connect_tcp(addr: &str) -> Result<Self> {
226 let socket_addr = addr
227 .to_socket_addrs()
228 .map_err(|e| Error::Connection(format!("invalid address '{}': {}", addr, e)))?
229 .next()
230 .ok_or_else(|| Error::Connection(format!("could not resolve address '{}'", addr)))?;
231
232 let stream = TcpStream::connect(&socket_addr)
233 .await
234 .map_err(|e| Error::Connection(format!("failed to connect to {}: {}", addr, e)))?;
235
236 stream.set_nodelay(true)
237 .map_err(|e| Error::Connection(format!("failed to set TCP_NODELAY: {}", e)))?;
238
239 Self::from_tcp_stream(stream).await
240 }
241
242 pub async fn from_tcp_stream(stream: TcpStream) -> Result<Self> {
246 let (reader, writer) = stream.into_split();
248
249 let reader = BufReader::new(reader.compat());
251 let writer = BufWriter::new(writer.compat_write());
252
253 let network = Box::new(twoparty::VatNetwork::new(
255 reader,
256 writer,
257 rpc_twoparty_capnp::Side::Client,
258 Default::default(),
259 ));
260
261 let mut rpc_system = RpcSystem::new(network, None);
263
264 let disconnector = rpc_system.get_disconnector();
266
267 let zap_client: zap_capnp::zap::Client =
269 rpc_system.bootstrap(rpc_twoparty_capnp::Side::Server);
270
271 tokio::task::spawn_local(rpc_system);
274
275 Ok(Self {
276 zap_client,
277 disconnector,
278 })
279 }
280
281 pub async fn init(&self, name: &str, version: &str) -> Result<ServerInfo> {
295 let mut request = self.zap_client.init_request();
296 {
297 let mut client_info = request.get().init_client();
298 client_info.set_name(name);
299 client_info.set_version(version);
300 }
301
302 let response = request.send().promise.await
303 .map_err(|e| Error::Protocol(format!("init failed: {}", e)))?;
304
305 let results = response.get()
306 .map_err(|e| Error::Protocol(format!("failed to get init results: {}", e)))?;
307
308 let server = results.get_server()
309 .map_err(|e| Error::Protocol(format!("failed to get server info: {}", e)))?;
310
311 let caps = server.get_capabilities()
312 .map_err(|e| Error::Protocol(format!("failed to get capabilities: {}", e)))?;
313
314 let name_reader = server.get_name()
315 .map_err(|e| Error::Protocol(format!("failed to get server name: {}", e)))?;
316 let version_reader = server.get_version()
317 .map_err(|e| Error::Protocol(format!("failed to get server version: {}", e)))?;
318
319 Ok(ServerInfo {
320 name: text_to_string(name_reader)?,
321 version: text_to_string(version_reader)?,
322 capabilities: ServerCapabilities {
323 tools: caps.get_tools(),
324 resources: caps.get_resources(),
325 prompts: caps.get_prompts(),
326 logging: caps.get_logging(),
327 },
328 })
329 }
330
331 pub async fn list_tools(&self) -> Result<Vec<Tool>> {
342 let request = self.zap_client.list_tools_request();
343 let response = request.send().promise.await
344 .map_err(|e| Error::Protocol(format!("list_tools failed: {}", e)))?;
345
346 let results = response.get()
347 .map_err(|e| Error::Protocol(format!("failed to get list_tools results: {}", e)))?;
348
349 let tool_list = results.get_tools()
350 .map_err(|e| Error::Protocol(format!("failed to get tool list: {}", e)))?;
351
352 let tools = tool_list.get_tools()
353 .map_err(|e| Error::Protocol(format!("failed to get tools: {}", e)))?;
354
355 let mut result = Vec::with_capacity(tools.len() as usize);
356 for tool in tools.iter() {
357 let name_reader = tool.get_name()
358 .map_err(|e| Error::Protocol(format!("failed to get tool name: {}", e)))?;
359 let desc_reader = tool.get_description()
360 .map_err(|e| Error::Protocol(format!("failed to get tool description: {}", e)))?;
361 let schema_bytes = tool.get_schema()
362 .map_err(|e| Error::Protocol(format!("failed to get tool schema: {}", e)))?;
363 let schema: Value = if schema_bytes.is_empty() {
364 Value::Object(serde_json::Map::new())
365 } else {
366 serde_json::from_slice(schema_bytes)
367 .map_err(|e| Error::Protocol(format!("failed to parse tool schema: {}", e)))?
368 };
369
370 result.push(Tool {
371 name: text_to_string(name_reader)?,
372 description: text_to_string(desc_reader)?,
373 schema,
374 });
375 }
376
377 Ok(result)
378 }
379
380 pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value> {
398 self.call_tool_with_id(uuid_v4(), name, args).await
399 }
400
401 pub async fn call_tool_with_id(&self, id: &str, name: &str, args: Value) -> Result<Value> {
405 let args_bytes = serde_json::to_vec(&args)?;
406
407 let mut request = self.zap_client.call_tool_request();
408 {
409 let mut call = request.get().init_call();
410 call.set_id(id);
411 call.set_name(name);
412 call.set_args(&args_bytes);
413 }
414
415 let response = request.send().promise.await
416 .map_err(|e| Error::Protocol(format!("call_tool failed: {}", e)))?;
417
418 let results = response.get()
419 .map_err(|e| Error::Protocol(format!("failed to get call_tool results: {}", e)))?;
420
421 let tool_result = results.get_result()
422 .map_err(|e| Error::Protocol(format!("failed to get tool result: {}", e)))?;
423
424 let error_reader = tool_result.get_error()
426 .map_err(|e| Error::Protocol(format!("failed to get error field: {}", e)))?;
427 if !error_reader.is_empty() {
428 return Err(Error::ToolCallFailed(text_to_string(error_reader)?));
429 }
430
431 let content_bytes = tool_result.get_content()
433 .map_err(|e| Error::Protocol(format!("failed to get content: {}", e)))?;
434
435 if content_bytes.is_empty() {
436 Ok(Value::Null)
437 } else {
438 serde_json::from_slice(content_bytes)
439 .map_err(|e| Error::Protocol(format!("failed to parse tool result: {}", e)))
440 }
441 }
442
443 pub async fn list_resources(&self) -> Result<Vec<Resource>> {
455 let request = self.zap_client.list_resources_request();
456 let response = request.send().promise.await
457 .map_err(|e| Error::Protocol(format!("list_resources failed: {}", e)))?;
458
459 let results = response.get()
460 .map_err(|e| Error::Protocol(format!("failed to get list_resources results: {}", e)))?;
461
462 let resource_list = results.get_resources()
463 .map_err(|e| Error::Protocol(format!("failed to get resource list: {}", e)))?;
464
465 let resources = resource_list.get_resources()
466 .map_err(|e| Error::Protocol(format!("failed to get resources: {}", e)))?;
467
468 let mut result = Vec::with_capacity(resources.len() as usize);
469 for resource in resources.iter() {
470 let uri_reader = resource.get_uri()
471 .map_err(|e| Error::Protocol(format!("failed to get resource uri: {}", e)))?;
472 let name_reader = resource.get_name()
473 .map_err(|e| Error::Protocol(format!("failed to get resource name: {}", e)))?;
474 let desc_reader = resource.get_description()
475 .map_err(|e| Error::Protocol(format!("failed to get resource description: {}", e)))?;
476 let mime_reader = resource.get_mime_type()
477 .map_err(|e| Error::Protocol(format!("failed to get resource mime_type: {}", e)))?;
478
479 result.push(Resource {
480 uri: text_to_string(uri_reader)?,
481 name: text_to_string(name_reader)?,
482 description: text_to_string(desc_reader)?,
483 mime_type: text_to_string(mime_reader)?,
484 });
485 }
486
487 Ok(result)
488 }
489
490 pub async fn read_resource(&self, uri: &str) -> Result<ResourceContent> {
502 let mut request = self.zap_client.read_resource_request();
503 request.get().set_uri(uri);
504
505 let response = request.send().promise.await
506 .map_err(|e| Error::Protocol(format!("read_resource failed: {}", e)))?;
507
508 let results = response.get()
509 .map_err(|e| Error::Protocol(format!("failed to get read_resource results: {}", e)))?;
510
511 let content = results.get_content()
512 .map_err(|e| Error::Protocol(format!("failed to get content: {}", e)))?;
513
514 convert_resource_content(content)
515 }
516
517 pub async fn subscribe(&self, uri: &str) -> Result<ResourceStream> {
530 let mut request = self.zap_client.subscribe_request();
531 request.get().set_uri(uri);
532
533 let response = request.send().promise.await
534 .map_err(|e| Error::Protocol(format!("subscribe failed: {}", e)))?;
535
536 let results = response.get()
537 .map_err(|e| Error::Protocol(format!("failed to get subscribe results: {}", e)))?;
538
539 let stream_client = results.get_stream()
540 .map_err(|e| Error::Protocol(format!("failed to get stream: {}", e)))?;
541
542 Ok(ResourceStream::new(stream_client))
543 }
544
545 pub async fn list_prompts(&self) -> Result<Vec<Prompt>> {
556 let request = self.zap_client.list_prompts_request();
557 let response = request.send().promise.await
558 .map_err(|e| Error::Protocol(format!("list_prompts failed: {}", e)))?;
559
560 let results = response.get()
561 .map_err(|e| Error::Protocol(format!("failed to get list_prompts results: {}", e)))?;
562
563 let prompt_list = results.get_prompts()
564 .map_err(|e| Error::Protocol(format!("failed to get prompt list: {}", e)))?;
565
566 let prompts = prompt_list.get_prompts()
567 .map_err(|e| Error::Protocol(format!("failed to get prompts: {}", e)))?;
568
569 let mut result = Vec::with_capacity(prompts.len() as usize);
570 for prompt in prompts.iter() {
571 let arguments = prompt.get_arguments()
572 .map_err(|e| Error::Protocol(format!("failed to get prompt arguments: {}", e)))?;
573
574 let mut args = Vec::with_capacity(arguments.len() as usize);
575 for arg in arguments.iter() {
576 let arg_name = arg.get_name()
577 .map_err(|e| Error::Protocol(format!("failed to get arg name: {}", e)))?;
578 let arg_desc = arg.get_description()
579 .map_err(|e| Error::Protocol(format!("failed to get arg description: {}", e)))?;
580 args.push(PromptArgument {
581 name: text_to_string(arg_name)?,
582 description: text_to_string(arg_desc)?,
583 required: arg.get_required(),
584 });
585 }
586
587 let prompt_name = prompt.get_name()
588 .map_err(|e| Error::Protocol(format!("failed to get prompt name: {}", e)))?;
589 let prompt_desc = prompt.get_description()
590 .map_err(|e| Error::Protocol(format!("failed to get prompt description: {}", e)))?;
591
592 result.push(Prompt {
593 name: text_to_string(prompt_name)?,
594 description: text_to_string(prompt_desc)?,
595 arguments: args,
596 });
597 }
598
599 Ok(result)
600 }
601
602 pub async fn get_prompt(&self, name: &str, args: &[(&str, &str)]) -> Result<Vec<PromptMessage>> {
621 let mut request = self.zap_client.get_prompt_request();
622 {
623 let mut params = request.get();
624 params.set_name(name);
625
626 let mut metadata = params.init_args();
627 let mut entries = metadata.init_entries(args.len() as u32);
628 for (i, (key, value)) in args.iter().enumerate() {
629 let mut entry = entries.reborrow().get(i as u32);
630 entry.set_key(*key);
631 entry.set_value(*value);
632 }
633 }
634
635 let response = request.send().promise.await
636 .map_err(|e| Error::Protocol(format!("get_prompt failed: {}", e)))?;
637
638 let results = response.get()
639 .map_err(|e| Error::Protocol(format!("failed to get get_prompt results: {}", e)))?;
640
641 let messages = results.get_messages()
642 .map_err(|e| Error::Protocol(format!("failed to get messages: {}", e)))?;
643
644 let mut result = Vec::with_capacity(messages.len() as usize);
645 for msg in messages.iter() {
646 let role = match msg.get_role()
647 .map_err(|e| Error::Protocol(format!("failed to get role: {}", e)))?
648 {
649 zap_capnp::prompt_message::Role::User => Role::User,
650 zap_capnp::prompt_message::Role::Assistant => Role::Assistant,
651 zap_capnp::prompt_message::Role::System => Role::System,
652 };
653
654 let content_reader = msg.get_content()
655 .map_err(|e| Error::Protocol(format!("failed to get content: {}", e)))?;
656
657 let content = match content_reader.which()
658 .map_err(|e| Error::Protocol(format!("failed to get content type: {}", e)))?
659 {
660 zap_capnp::prompt_message::content::Which::Text(text_reader) => {
661 let text_reader = text_reader
662 .map_err(|e| Error::Protocol(format!("failed to get text: {}", e)))?;
663 MessageContent::Text(text_to_string(text_reader)?)
664 }
665 zap_capnp::prompt_message::content::Which::Image(image) => {
666 let image = image
667 .map_err(|e| Error::Protocol(format!("failed to get image: {}", e)))?;
668 let mime_reader = image.get_mime_type()
669 .map_err(|e| Error::Protocol(format!("failed to get image mime_type: {}", e)))?;
670 MessageContent::Image {
671 data: image.get_data()
672 .map_err(|e| Error::Protocol(format!("failed to get image data: {}", e)))?
673 .to_vec(),
674 mime_type: text_to_string(mime_reader)?,
675 }
676 }
677 zap_capnp::prompt_message::content::Which::Resource(resource) => {
678 let resource = resource
679 .map_err(|e| Error::Protocol(format!("failed to get resource: {}", e)))?;
680 MessageContent::Resource(convert_resource_content(resource)?)
681 }
682 };
683
684 result.push(PromptMessage { role, content });
685 }
686
687 Ok(result)
688 }
689
690 pub async fn log(&self, level: LogLevel, message: &str, data: Option<Value>) -> Result<()> {
709 let mut request = self.zap_client.log_request();
710 {
711 let mut params = request.get();
712 params.set_level(match level {
713 LogLevel::Debug => zap_capnp::zap::LogLevel::Debug,
714 LogLevel::Info => zap_capnp::zap::LogLevel::Info,
715 LogLevel::Warn => zap_capnp::zap::LogLevel::Warn,
716 LogLevel::Error => zap_capnp::zap::LogLevel::Error,
717 });
718 params.set_message(message);
719 if let Some(data) = data {
720 let data_bytes = serde_json::to_vec(&data)?;
721 params.set_data(&data_bytes);
722 }
723 }
724
725 request.send().promise.await
726 .map_err(|e| Error::Protocol(format!("log failed: {}", e)))?;
727
728 Ok(())
729 }
730
731 pub async fn disconnect(self) -> Result<()> {
735 self.disconnector.await
736 .map_err(|e| Error::Connection(format!("disconnect failed: {}", e)))
737 }
738}
739
/// Severity attached to a message sent with `Client::log`.
///
/// Declaration order fixes the discriminants: `Debug` = 0 through `Error` = 3.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LogLevel {
    Debug,
    Info,
    Warn,
    Error,
}
748
749fn convert_resource_content(
751 content: zap_capnp::resource_content::Reader<'_>
752) -> Result<ResourceContent> {
753 let uri_reader = content.get_uri()
754 .map_err(|e| Error::Protocol(format!("failed to get uri: {}", e)))?;
755 let uri = uri_reader.to_str()
756 .map_err(|e| Error::Protocol(format!("invalid utf8 in uri: {}", e)))?
757 .to_string();
758
759 let mime_reader = content.get_mime_type()
760 .map_err(|e| Error::Protocol(format!("failed to get mime_type: {}", e)))?;
761 let mime_type = mime_reader.to_str()
762 .map_err(|e| Error::Protocol(format!("invalid utf8 in mime_type: {}", e)))?
763 .to_string();
764
765 let content_data = match content.get_content().which()
766 .map_err(|e| Error::Protocol(format!("failed to get content type: {}", e)))?
767 {
768 zap_capnp::resource_content::content::Which::Text(text) => {
769 let text_reader = text
770 .map_err(|e| Error::Protocol(format!("failed to get text: {}", e)))?;
771 let text_str = text_reader.to_str()
772 .map_err(|e| Error::Protocol(format!("invalid utf8 in text: {}", e)))?;
773 Content::Text(text_str.to_string())
774 }
775 zap_capnp::resource_content::content::Which::Blob(blob) => {
776 let blob_data = blob
777 .map_err(|e| Error::Protocol(format!("failed to get blob: {}", e)))?;
778 Content::Blob(blob_data.to_vec())
779 }
780 };
781
782 Ok(ResourceContent {
783 uri,
784 mime_type,
785 content: content_data,
786 })
787}
788
/// Returns a tool-call identifier string.
///
/// NOTE(review): despite the name, this is a fixed placeholder (the nil
/// UUID), not a random version-4 UUID. Every caller receives the same value.
fn uuid_v4() -> &'static str {
    const NIL_UUID: &str = "00000000-0000-0000-0000-000000000000";
    NIL_UUID
}
795
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the discriminant order of `LogLevel`.
    #[test]
    fn test_log_level_conversion() {
        assert_eq!(LogLevel::Debug as u8, 0);
        assert_eq!(LogLevel::Info as u8, 1);
        assert_eq!(LogLevel::Warn as u8, 2);
        assert_eq!(LogLevel::Error as u8, 3);
    }

    /// Checks the derived `Debug` output of both `Content` variants.
    /// Previously the formatted strings were discarded without any assertion.
    #[test]
    fn test_content_debug() {
        let text = Content::Text("hello".to_string());
        let blob = Content::Blob(vec![1, 2, 3]);

        assert_eq!(format!("{:?}", text), "Text(\"hello\")");
        assert_eq!(format!("{:?}", blob), "Blob([1, 2, 3])");
    }

    /// `Role` equality distinguishes every variant.
    #[test]
    fn test_role_equality() {
        assert_eq!(Role::User, Role::User);
        assert_ne!(Role::User, Role::Assistant);
        assert_ne!(Role::Assistant, Role::System);
        assert_ne!(Role::User, Role::System);
    }
}