myst_client/
transport.rs

1pub mod gossip;
2
3use anyhow::Error as AnyhowError;
4use async_openai::types::{CreateChatCompletionRequest, CreateChatCompletionStreamResponse};
5use base32::Alphabet;
6use earendil_crypt::{HavenEndpoint, HavenFingerprint};
7use futures::{Stream, StreamExt};
8use hyper_util::rt::TokioIo;
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10use std::{
11    net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
12    pin::Pin,
13    str::FromStr,
14};
15use thiserror::Error;
16use tokio_socks::tcp::Socks5Stream;
17use tonic::transport::Channel;
18
19use crate::{
20    proto::{myst_service_client::MystServiceClient, GossipRequest, ProxyRequest},
21    task::TaskContext,
22};
23
24use self::gossip::NetworkState;
25
26#[derive(Clone)]
27pub struct RpcClient {
28    inner: MystServiceClient<Channel>,
29}
30
31impl RpcClient {
32    pub async fn new(
33        node_id: NodeId,
34        proxy_addr: Option<SocketAddr>,
35    ) -> Result<Self, TransportError> {
36        eprintln!("Creating new RPC client with node ID: {node_id}");
37
38        let channel = match (node_id, proxy_addr) {
39            (NodeId::Direct(addr), _) => Channel::from_shared(format!("http://{}", addr))
40                .unwrap()
41                .connect()
42                .await
43                .unwrap(),
44            (NodeId::Earendil(haven_endpoint), Some(proxy_addr)) => {
45                let haven_url = format!(
46                    "{}.haven:{}",
47                    haven_endpoint.fingerprint, haven_endpoint.port
48                );
49
50                eprintln!("haven URL: {haven_url}");
51
52                Channel::from_shared(format!("http://{}", haven_url))
53                    .map_err(|e| TransportError::Earendil(e.to_string()))?
54                    .connect_with_connector(tower::service_fn(move |_| {
55                        let proxy = proxy_addr;
56                        eprintln!("PROXY addr: {proxy}");
57                        let target = haven_url.clone();
58
59                        async move {
60                            // NOTE: This creates a new SOCKS5 connection for each request
61                            // maybe we can cache/reuse the connection?
62                            let socks_stream =
63                                Socks5Stream::connect(proxy, target).await.map_err(|e| {
64                                    std::io::Error::new(std::io::ErrorKind::Other, e.to_string())
65                                })?;
66
67                            Ok::<_, std::io::Error>(TokioIo::new(socks_stream.into_inner()))
68                        }
69                    }))
70                    .await
71                    .map_err(|e| TransportError::Earendil(e.to_string()))?
72            }
73            (NodeId::Tor(_), Some(_proxy_addr)) => {
74                return Err(TransportError::Network(
75                    "Tor transport not implemented".into(),
76                ));
77            }
78            (NodeId::Nym(_), Some(_proxy_addr)) => {
79                return Err(TransportError::Network(
80                    "Nym transport not implemented".into(),
81                ));
82            }
83            _ => return Err(TransportError::Network("unknown transport error".into())),
84        };
85
86        Ok(Self {
87            inner: MystServiceClient::new(channel),
88        })
89    }
90
91    pub async fn gossip(&mut self, state: NetworkState) -> anyhow::Result<()> {
92        let bytes = serde_json::to_vec(&state)?;
93        self.inner
94            .gossip(GossipRequest {
95                network_state: bytes,
96            })
97            .await?;
98        Ok(())
99    }
100
101    pub async fn compute_text(
102        &mut self,
103        request: CreateChatCompletionRequest,
104        ctx: TaskContext,
105    ) -> anyhow::Result<
106        Pin<
107            Box<
108                dyn Stream<Item = Result<CreateChatCompletionStreamResponse, anyhow::Error>> + Send,
109            >,
110        >,
111    > {
112        let proxy_request = ProxyRequest {
113            task_ctx: serde_json::to_vec(&ctx)?,
114            request: serde_json::to_vec(&request)?,
115        };
116
117        eprintln!("sending proxy stream request to node: {:?}", &ctx.node_id);
118        let response = self.inner.proxy_stream(proxy_request).await?;
119        let stream = response.into_inner();
120
121        // Transform tonic::Streaming into a standard Stream
122        let transformed_stream = stream.map(|result| {
123            result
124                .map_err(anyhow::Error::from)
125                .and_then(|proxy_response| {
126                    serde_json::from_slice(&proxy_response.response).map_err(anyhow::Error::from)
127                })
128        });
129
130        Ok(Box::pin(transformed_stream))
131    }
132
133    pub async fn generate_image(
134        &mut self,
135        prompt: String,
136        ctx: TaskContext,
137    ) -> anyhow::Result<String> {
138        todo!()
139    }
140}
141
142#[derive(Error, Debug)]
143pub enum TransportError {
144    #[error("earendil connection error: {0}")]
145    Earendil(String),
146    #[error("TCP connection error: {0}")]
147    Direct(String),
148    #[error("Unknown error: {0}")]
149    Network(String),
150    #[error("Socks5 error: {0}")]
151    Socks5(tokio_socks::Error),
152}
153
154impl From<AnyhowError> for TransportError {
155    fn from(err: AnyhowError) -> Self {
156        TransportError::Network(err.to_string())
157    }
158}
159
160impl From<tokio_socks::Error> for TransportError {
161    fn from(err: tokio_socks::Error) -> Self {
162        TransportError::Socks5(err)
163    }
164}
165
166/// ID for uniquely identifying a node in the network.
167/// This wraps the underlying transport address to keep things simple.
168#[derive(Clone, Debug, PartialEq, Eq, Hash)]
169pub enum NodeId {
170    Direct(SocketAddr),
171    Earendil(HavenEndpoint),
172    Tor(String),
173    Nym(String),
174}
175
176impl Serialize for NodeId {
177    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
178    where
179        S: Serializer,
180    {
181        serializer.serialize_str(&self.to_string())
182    }
183}
184
185impl<'de> Deserialize<'de> for NodeId {
186    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
187    where
188        D: Deserializer<'de>,
189    {
190        let s = String::deserialize(deserializer)?;
191        s.parse().map_err(serde::de::Error::custom)
192    }
193}
194
195impl std::fmt::Display for NodeId {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        match self {
198            NodeId::Direct(addr) => write!(f, "{}", addr),
199            NodeId::Earendil(endpoint) => write!(f, "{}:{}", endpoint.fingerprint, endpoint.port),
200            NodeId::Tor(addr) => write!(f, "{}", addr),
201            NodeId::Nym(addr) => write!(f, "{}", addr),
202        }
203    }
204}
205
206impl FromStr for NodeId {
207    type Err = anyhow::Error;
208
209    fn from_str(s: &str) -> Result<Self, anyhow::Error> {
210        if s.is_empty() {
211            return Err(anyhow::anyhow!("Empty string"));
212        }
213
214        // Try to parse as socket address first
215        if let Ok(addr) = s.parse::<SocketAddr>() {
216            return Ok(NodeId::Direct(addr));
217        }
218
219        // Try to parse as Earendil endpoint (fingerprint:port)
220        if let Some((fingerprint, port)) = s.split_once(':') {
221            if let (Ok(fingerprint), Ok(port)) =
222                (HavenFingerprint::from_str(fingerprint), port.parse::<u16>())
223            {
224                return Ok(NodeId::Earendil(HavenEndpoint { fingerprint, port }));
225            }
226        }
227
228        // Could add Tor/Nym parsing here if needed
229        Err(anyhow::anyhow!("Invalid node ID format"))
230    }
231}