1pub mod gossip;
2
3use anyhow::Error as AnyhowError;
4use async_openai::types::{CreateChatCompletionRequest, CreateChatCompletionStreamResponse};
5use base32::Alphabet;
6use earendil_crypt::{HavenEndpoint, HavenFingerprint};
7use futures::{Stream, StreamExt};
8use hyper_util::rt::TokioIo;
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10use std::{
11 net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
12 pin::Pin,
13 str::FromStr,
14};
15use thiserror::Error;
16use tokio_socks::tcp::Socks5Stream;
17use tonic::transport::Channel;
18
19use crate::{
20 proto::{myst_service_client::MystServiceClient, GossipRequest, ProxyRequest},
21 task::TaskContext,
22};
23
24use self::gossip::NetworkState;
25
26#[derive(Clone)]
27pub struct RpcClient {
28 inner: MystServiceClient<Channel>,
29}
30
31impl RpcClient {
32 pub async fn new(
33 node_id: NodeId,
34 proxy_addr: Option<SocketAddr>,
35 ) -> Result<Self, TransportError> {
36 eprintln!("Creating new RPC client with node ID: {node_id}");
37
38 let channel = match (node_id, proxy_addr) {
39 (NodeId::Direct(addr), _) => Channel::from_shared(format!("http://{}", addr))
40 .unwrap()
41 .connect()
42 .await
43 .unwrap(),
44 (NodeId::Earendil(haven_endpoint), Some(proxy_addr)) => {
45 let haven_url = format!(
46 "{}.haven:{}",
47 haven_endpoint.fingerprint, haven_endpoint.port
48 );
49
50 eprintln!("haven URL: {haven_url}");
51
52 Channel::from_shared(format!("http://{}", haven_url))
53 .map_err(|e| TransportError::Earendil(e.to_string()))?
54 .connect_with_connector(tower::service_fn(move |_| {
55 let proxy = proxy_addr;
56 eprintln!("PROXY addr: {proxy}");
57 let target = haven_url.clone();
58
59 async move {
60 let socks_stream =
63 Socks5Stream::connect(proxy, target).await.map_err(|e| {
64 std::io::Error::new(std::io::ErrorKind::Other, e.to_string())
65 })?;
66
67 Ok::<_, std::io::Error>(TokioIo::new(socks_stream.into_inner()))
68 }
69 }))
70 .await
71 .map_err(|e| TransportError::Earendil(e.to_string()))?
72 }
73 (NodeId::Tor(_), Some(_proxy_addr)) => {
74 return Err(TransportError::Network(
75 "Tor transport not implemented".into(),
76 ));
77 }
78 (NodeId::Nym(_), Some(_proxy_addr)) => {
79 return Err(TransportError::Network(
80 "Nym transport not implemented".into(),
81 ));
82 }
83 _ => return Err(TransportError::Network("unknown transport error".into())),
84 };
85
86 Ok(Self {
87 inner: MystServiceClient::new(channel),
88 })
89 }
90
91 pub async fn gossip(&mut self, state: NetworkState) -> anyhow::Result<()> {
92 let bytes = serde_json::to_vec(&state)?;
93 self.inner
94 .gossip(GossipRequest {
95 network_state: bytes,
96 })
97 .await?;
98 Ok(())
99 }
100
101 pub async fn compute_text(
102 &mut self,
103 request: CreateChatCompletionRequest,
104 ctx: TaskContext,
105 ) -> anyhow::Result<
106 Pin<
107 Box<
108 dyn Stream<Item = Result<CreateChatCompletionStreamResponse, anyhow::Error>> + Send,
109 >,
110 >,
111 > {
112 let proxy_request = ProxyRequest {
113 task_ctx: serde_json::to_vec(&ctx)?,
114 request: serde_json::to_vec(&request)?,
115 };
116
117 eprintln!("sending proxy stream request to node: {:?}", &ctx.node_id);
118 let response = self.inner.proxy_stream(proxy_request).await?;
119 let stream = response.into_inner();
120
121 let transformed_stream = stream.map(|result| {
123 result
124 .map_err(anyhow::Error::from)
125 .and_then(|proxy_response| {
126 serde_json::from_slice(&proxy_response.response).map_err(anyhow::Error::from)
127 })
128 });
129
130 Ok(Box::pin(transformed_stream))
131 }
132
133 pub async fn generate_image(
134 &mut self,
135 prompt: String,
136 ctx: TaskContext,
137 ) -> anyhow::Result<String> {
138 todo!()
139 }
140}
141
142#[derive(Error, Debug)]
143pub enum TransportError {
144 #[error("earendil connection error: {0}")]
145 Earendil(String),
146 #[error("TCP connection error: {0}")]
147 Direct(String),
148 #[error("Unknown error: {0}")]
149 Network(String),
150 #[error("Socks5 error: {0}")]
151 Socks5(tokio_socks::Error),
152}
153
154impl From<AnyhowError> for TransportError {
155 fn from(err: AnyhowError) -> Self {
156 TransportError::Network(err.to_string())
157 }
158}
159
160impl From<tokio_socks::Error> for TransportError {
161 fn from(err: tokio_socks::Error) -> Self {
162 TransportError::Socks5(err)
163 }
164}
165
166#[derive(Clone, Debug, PartialEq, Eq, Hash)]
169pub enum NodeId {
170 Direct(SocketAddr),
171 Earendil(HavenEndpoint),
172 Tor(String),
173 Nym(String),
174}
175
176impl Serialize for NodeId {
177 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
178 where
179 S: Serializer,
180 {
181 serializer.serialize_str(&self.to_string())
182 }
183}
184
185impl<'de> Deserialize<'de> for NodeId {
186 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
187 where
188 D: Deserializer<'de>,
189 {
190 let s = String::deserialize(deserializer)?;
191 s.parse().map_err(serde::de::Error::custom)
192 }
193}
194
195impl std::fmt::Display for NodeId {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 match self {
198 NodeId::Direct(addr) => write!(f, "{}", addr),
199 NodeId::Earendil(endpoint) => write!(f, "{}:{}", endpoint.fingerprint, endpoint.port),
200 NodeId::Tor(addr) => write!(f, "{}", addr),
201 NodeId::Nym(addr) => write!(f, "{}", addr),
202 }
203 }
204}
205
206impl FromStr for NodeId {
207 type Err = anyhow::Error;
208
209 fn from_str(s: &str) -> Result<Self, anyhow::Error> {
210 if s.is_empty() {
211 return Err(anyhow::anyhow!("Empty string"));
212 }
213
214 if let Ok(addr) = s.parse::<SocketAddr>() {
216 return Ok(NodeId::Direct(addr));
217 }
218
219 if let Some((fingerprint, port)) = s.split_once(':') {
221 if let (Ok(fingerprint), Ok(port)) =
222 (HavenFingerprint::from_str(fingerprint), port.parse::<u16>())
223 {
224 return Ok(NodeId::Earendil(HavenEndpoint { fingerprint, port }));
225 }
226 }
227
228 Err(anyhow::anyhow!("Invalid node ID format"))
230 }
231}