1use anyhow::{anyhow, bail};
2use quic_rpc::{transport::mem, Service};
3use serde_with::{DeserializeFromStr, SerializeDisplay};
4use std::{
5 fmt::{Debug, Display},
6 net::SocketAddr,
7 str::FromStr,
8};
9
/// Address of an RPC endpoint for the service `S`.
///
/// Serialization goes through the `Display` impl and deserialization through the
/// `FromStr` impl (via `serde_with::SerializeDisplay` / `DeserializeFromStr`), so the
/// textual forms are `irpc://<target>` for network addresses and `mem` for in-memory ones.
#[derive(SerializeDisplay, DeserializeFromStr)]
pub enum Addr<S: Service> {
    /// An `irpc://` address whose target parsed as a concrete `SocketAddr`.
    Irpc(SocketAddr),
    /// An `irpc://` address whose target did NOT parse as a `SocketAddr` (e.g. a host
    /// name); stored verbatim. Presumably resolved by the caller later — TODO confirm.
    IrpcLookup(String),
    /// An in-process connection: the server and client halves returned by
    /// `quic_rpc::transport::mem::connection`. Displays as `mem`, but `FromStr`
    /// refuses to parse it back.
    Mem(
        mem::ServerChannel<S::Req, S::Res>,
        mem::ClientChannel<S::Res, S::Req>,
    ),
}
21
22impl<S: Service> PartialEq for Addr<S> {
23 fn eq(&self, other: &Self) -> bool {
24 match (self, other) {
25 (Self::Irpc(addr1), Self::Irpc(addr2)) => addr1.eq(addr2),
26 (Self::IrpcLookup(addr1), Self::IrpcLookup(addr2)) => addr1.eq(addr2),
27 _ => false,
28 }
29 }
30}
31
32impl<S: Service> Addr<S> {
33 pub fn new_mem() -> Self {
34 let (server, client) = mem::connection(256);
35
36 Self::Mem(server, client)
37 }
38}
39
40impl<S: Service> Addr<S> {
41 pub fn try_as_socket_addr(&self) -> Option<SocketAddr> {
42 if let Addr::Irpc(addr) = self {
43 return Some(*addr);
44 }
45 None
46 }
47}
48
49impl<S: Service> Display for Addr<S> {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 Addr::Irpc(addr) => write!(f, "irpc://{addr}"),
53 Addr::IrpcLookup(addr) => write!(f, "irpc://{addr}"),
54 Addr::Mem(_, _) => write!(f, "mem"),
55 #[allow(unreachable_patterns)]
56 _ => unreachable!(),
57 }
58 }
59}
60
61impl<S: Service> Debug for Addr<S> {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 Display::fmt(self, f)
64 }
65}
66
67impl<S: Service> Clone for Addr<S> {
68 fn clone(&self) -> Self {
69 match self {
70 Addr::Irpc(addr) => Addr::Irpc(*addr),
71 Addr::IrpcLookup(addr) => Addr::IrpcLookup(addr.clone()),
72 Addr::Mem(server, client) => Addr::Mem(server.clone(), client.clone()),
73 }
74 }
75}
76
77impl<S: Service> FromStr for Addr<S> {
78 type Err = anyhow::Error;
79
80 fn from_str(s: &str) -> Result<Self, Self::Err> {
81 if s == "mem" {
82 bail!("memory addresses can not be serialized or deserialized");
83 }
84
85 let mut parts = s.splitn(2, "://");
86 if let Some(prefix) = parts.next() {
87 if prefix == "irpc" {
88 if let Some(part) = parts.next() {
89 return Ok(if let Ok(addr) = part.parse() {
90 Addr::Irpc(addr)
91 } else {
92 Addr::IrpcLookup(part.to_string())
93 });
94 }
95 }
96 }
97
98 Err(anyhow!("invalid addr: {}", s))
99 }
100}
101
#[cfg(test)]
mod tests {
    use crate::gateway::GatewayAddr;
    use crate::Addr;
    use std::net::SocketAddr;

    /// A concrete socket address survives a Display → FromStr round trip.
    #[test]
    fn test_addr_roundtrip_irpc_http2() {
        let socket: SocketAddr = "198.168.2.1:1234".parse().unwrap();
        let addr = Addr::Irpc(socket);

        assert_eq!(addr.to_string().parse::<GatewayAddr>().unwrap(), addr);
        assert_eq!(addr.to_string(), "irpc://198.168.2.1:1234");
    }

    /// A target that is not an IP literal is kept verbatim as `IrpcLookup` and round-trips.
    #[test]
    fn test_addr_roundtrip_irpc_lookup() {
        // "localhost" is not an IP literal, so this does not parse as a `SocketAddr`.
        let addr = Addr::IrpcLookup("localhost:1234".to_string());

        assert_eq!(addr.to_string().parse::<GatewayAddr>().unwrap(), addr);
        assert_eq!(addr.to_string(), "irpc://localhost:1234");
    }

    /// `mem`, other schemes, and scheme-less strings are all parse errors.
    #[test]
    fn test_addr_parse_rejects_mem_and_unknown_schemes() {
        assert!("mem".parse::<GatewayAddr>().is_err());
        assert!("http://198.168.2.1:1234".parse::<GatewayAddr>().is_err());
        assert!("198.168.2.1:1234".parse::<GatewayAddr>().is_err());
    }
}