modality_ingest_client/
client.rs

1use modality_api::types::{AttrKey, AttrVal, TimelineId};
2use modality_ingest_protocol::{IngestMessage, IngestResponse, InternedAttrKey, PackedAttrKvs};
3use std::{net::SocketAddr, path::PathBuf, time::Duration};
4use thiserror::Error;
5use tokio::{
6    io::{AsyncReadExt, AsyncWriteExt},
7    net::{TcpSocket, TcpStream},
8    time::timeout,
9};
10use tokio_native_tls::TlsStream;
11use url::Url;
12
13pub struct IngestClient<S> {
14    #[allow(unused)]
15    pub(crate) state: S,
16    pub(crate) common: IngestClientCommon,
17}
18
19pub struct UnauthenticatedState {}
20pub struct ReadyState {}
21pub struct BoundTimelineState {
22    pub(crate) timeline_id: TimelineId,
23}
24
25/// Fields used by the client in every state
26#[doc(hidden)]
27pub struct IngestClientCommon {
28    pub timeout: Duration,
29    connection: IngestConnection,
30    next_id: u32,
31}
32
33impl IngestClientCommon {
34    #[doc(hidden)]
35    pub fn new(timeout: Duration, connection: IngestConnection) -> Self {
36        IngestClientCommon {
37            timeout,
38            connection,
39            next_id: 0,
40        }
41    }
42
43    /// Send a message and wait for a required response.
44    #[doc(hidden)]
45    pub async fn send_recv(&mut self, msg: &IngestMessage) -> Result<IngestResponse, IngestError> {
46        self.connection.write_msg(msg).await?;
47        timeout(self.timeout, self.connection.read_msg()).await?
48    }
49
50    /// Send a message.
51    #[doc(hidden)]
52    pub async fn send(&mut self, msg: &IngestMessage) -> Result<(), IngestError> {
53        self.connection.write_msg(msg).await
54    }
55
56    pub(crate) async fn declare_attr_key<K: Into<AttrKey>>(
57        &mut self,
58        key_name: K,
59    ) -> Result<InternedAttrKey, IngestError> {
60        let key_name = key_name.into();
61
62        if !(key_name.as_ref().starts_with("timeline.") || key_name.as_ref().starts_with("event."))
63        {
64            return Err(IngestError::AttrKeyNaming);
65        }
66
67        let wire_id = self.next_id;
68        self.next_id += 1;
69        let wire_id = wire_id.into();
70
71        self.send(&IngestMessage::DeclareAttrKey {
72            name: key_name.into(),
73            wire_id,
74        })
75        .await?;
76
77        Ok(wire_id)
78    }
79}
80
/// How strictly a TLS ingest connection validates the server's certificate.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsMode {
    /// Reject invalid server certificates.
    Secure,
    /// Accept invalid server certificates (selected by `allow_insecure_tls`).
    Insecure,
}
86
87pub enum IngestConnection {
88    Tcp(TcpStream),
89    Tls(TlsStream<TcpStream>),
90}
91
92impl IngestConnection {
93    pub async fn connect(
94        endpoint: &Url,
95        allow_insecure_tls: bool,
96    ) -> Result<IngestConnection, IngestClientInitializationError> {
97        let endpoint = IngestEndpoint::parse_and_resolve(endpoint, allow_insecure_tls).await?;
98
99        // take the first addr, arbitrarily
100        let remote_addr = endpoint
101            .addrs
102            .into_iter()
103            .next()
104            .ok_or(IngestClientInitializationError::NoIps)?;
105
106        let local_addr: SocketAddr = if remote_addr.is_ipv4() {
107            "0.0.0.0:0"
108        } else {
109            "[::]:0"
110        }
111        .parse()?;
112
113        let socket = if remote_addr.is_ipv4() {
114            TcpSocket::new_v4().map_err(IngestClientInitializationError::SocketInit)?
115        } else {
116            TcpSocket::new_v6().map_err(IngestClientInitializationError::SocketInit)?
117        };
118
119        socket
120            .bind(local_addr)
121            .map_err(IngestClientInitializationError::SocketInit)?;
122        let stream = socket.connect(remote_addr).await.map_err(|error| {
123            IngestClientInitializationError::SocketConnection { error, remote_addr }
124        })?;
125
126        if let Some(tls_mode) = endpoint.tls_mode {
127            let cx = native_tls::TlsConnector::builder()
128                .danger_accept_invalid_certs(match tls_mode {
129                    TlsMode::Secure => false,
130                    TlsMode::Insecure => true,
131                })
132                .build()?;
133            let cx = tokio_native_tls::TlsConnector::from(cx);
134            let stream = cx.connect(&endpoint.cert_domain, stream).await?;
135            Ok(IngestConnection::Tls(stream))
136        } else {
137            Ok(IngestConnection::Tcp(stream))
138        }
139    }
140
141    async fn write_msg(&mut self, msg: &IngestMessage) -> Result<(), IngestError> {
142        let msg_buf = minicbor::to_vec(msg)?;
143        let msg_len = msg_buf.len() as u32;
144
145        match self {
146            IngestConnection::Tcp(s) => {
147                s.write_all(&msg_len.to_be_bytes())
148                    .await
149                    .map_err(minicbor::encode::Error::Write)?;
150                s.write_all(&msg_buf)
151                    .await
152                    .map_err(minicbor::encode::Error::Write)?;
153            }
154            IngestConnection::Tls(s) => {
155                // We have to use write_all here, because https://github.com/tokio-rs/tls/issues/41
156                s.write_all(&msg_len.to_be_bytes())
157                    .await
158                    .map_err(minicbor::encode::Error::Write)?;
159                s.write_all(&msg_buf)
160                    .await
161                    .map_err(minicbor::encode::Error::Write)?;
162            }
163        }
164
165        Ok(())
166    }
167
168    async fn read_msg(&mut self) -> Result<IngestResponse, IngestError> {
169        match self {
170            IngestConnection::Tcp(s) => {
171                let msg_len = s.read_u32().await?; // yes, this is big-endian
172                let mut msg_buf = vec![0u8; msg_len as usize];
173                s.read_exact(msg_buf.as_mut_slice()).await?;
174
175                Ok(minicbor::decode::<IngestResponse>(&msg_buf)?)
176            }
177            IngestConnection::Tls(s) => {
178                let msg_len = s.read_u32().await?; // yes, this is big-endian
179                let mut msg_buf = vec![0u8; msg_len as usize];
180                s.read_exact(msg_buf.as_mut_slice()).await?;
181
182                Ok(minicbor::decode::<IngestResponse>(&msg_buf)?)
183            }
184        }
185    }
186}
187
188impl IngestClient<UnauthenticatedState> {
189    /// Create a new ingest client.
190    pub async fn connect(
191        endpoint: &Url,
192        allow_insecure_tls: bool,
193    ) -> Result<IngestClient<UnauthenticatedState>, IngestClientInitializationError> {
194        let connection = IngestConnection::connect(endpoint, allow_insecure_tls).await?;
195        let common = IngestClientCommon::new(Duration::from_secs(1), connection);
196
197        Ok(IngestClient {
198            state: UnauthenticatedState {},
199            common,
200        })
201    }
202
203    /// Create a new ingest client.
204    pub async fn connect_with_timeout(
205        endpoint: &Url,
206        allow_insecure_tls: bool,
207        timeout: Duration,
208    ) -> Result<IngestClient<UnauthenticatedState>, IngestClientInitializationError> {
209        let connection = IngestConnection::connect(endpoint, allow_insecure_tls).await?;
210        let common = IngestClientCommon::new(timeout, connection);
211
212        Ok(IngestClient {
213            state: UnauthenticatedState {},
214            common,
215        })
216    }
217
218    pub async fn authenticate(
219        mut self,
220        token: Vec<u8>,
221    ) -> Result<IngestClient<ReadyState>, IngestError> {
222        let resp = self
223            .common
224            .send_recv(&IngestMessage::AuthRequest { token })
225            .await?;
226
227        match resp {
228            IngestResponse::AuthResponse { ok, message } => {
229                if ok {
230                    Ok(IngestClient {
231                        state: ReadyState {},
232                        common: self.common,
233                    })
234                } else {
235                    Err(IngestError::AuthenticationError {
236                        message,
237                        client: Box::new(self),
238                    })
239                }
240            }
241            _ => Err(IngestError::ProtocolError(
242                "Invalid response received in the 'Unauthenticated' state.",
243            )),
244        }
245    }
246}
247
248impl IngestClient<ReadyState> {
249    /// Create a fully authorized client connection, using the
250    /// standard config file location and environment variables.
251    pub async fn connect_with_standard_config(
252        timeout: Duration,
253        manually_provided_config_path: Option<PathBuf>,
254        manually_provided_auth_token: Option<PathBuf>,
255    ) -> Result<IngestClient<ReadyState>, IngestError> {
256        let (config, auth_token) = modality_reflector_config::resolve::load_config_and_auth_token(
257            manually_provided_config_path,
258            manually_provided_auth_token,
259        )
260        .map_err(IngestError::LoadConfigError)?;
261
262        let mut endpoint = None;
263        let mut allow_insecure_tls = false;
264        if let Some(ingest) = config.ingest {
265            allow_insecure_tls = ingest.allow_insecure_tls;
266            endpoint = ingest.protocol_parent_url;
267        };
268
269        let endpoint =
270            endpoint.unwrap_or_else(|| Url::parse("modality-ingest://127.0.0.1").unwrap());
271
272        let client = IngestClient::<UnauthenticatedState>::connect_with_timeout(
273            &endpoint,
274            allow_insecure_tls,
275            timeout,
276        )
277        .await?;
278
279        client.authenticate(auth_token.into()).await
280    }
281
282    pub async fn open_timeline(
283        mut self,
284        id: TimelineId,
285    ) -> Result<IngestClient<BoundTimelineState>, IngestError> {
286        self.common
287            .send(&IngestMessage::OpenTimeline { id })
288            .await?;
289
290        Ok(IngestClient {
291            state: BoundTimelineState { timeline_id: id },
292            common: self.common,
293        })
294    }
295
296    pub async fn declare_attr_key(
297        &mut self,
298        key_name: String,
299    ) -> Result<InternedAttrKey, IngestError> {
300        self.common.declare_attr_key(key_name).await
301    }
302}
303
304impl IngestClient<BoundTimelineState> {
305    pub fn bound_timeline(&self) -> TimelineId {
306        self.state.timeline_id
307    }
308
309    pub async fn open_timeline(&mut self, id: TimelineId) -> Result<(), IngestError> {
310        self.common
311            .send(&IngestMessage::OpenTimeline { id })
312            .await?;
313        self.state.timeline_id = id;
314        Ok(())
315    }
316
317    /// This doesn't change the connection state, but it does require you to open_timeline again
318    /// before you can do anything else.
319    pub fn close_timeline(self) -> IngestClient<ReadyState> {
320        IngestClient {
321            state: ReadyState {},
322            common: self.common,
323        }
324    }
325
326    pub async fn declare_attr_key(
327        &mut self,
328        key_name: String,
329    ) -> Result<InternedAttrKey, IngestError> {
330        self.common.declare_attr_key(key_name).await
331    }
332
333    pub async fn timeline_metadata(
334        &mut self,
335        attrs: impl IntoIterator<Item = (InternedAttrKey, AttrVal)>,
336    ) -> Result<(), IngestError> {
337        self.common.timeline_metadata(attrs).await
338    }
339
340    pub async fn event(
341        &mut self,
342        ordering: u128,
343        attrs: impl IntoIterator<Item = (InternedAttrKey, AttrVal)>,
344    ) -> Result<(), IngestError> {
345        self.common.event(ordering, attrs).await
346    }
347
348    // TODO make a blocking_flush as well, good for tests
349    pub async fn flush(&mut self) -> Result<(), IngestError> {
350        self.common.flush().await
351    }
352
353    pub async fn status(&mut self) -> Result<IngestStatus, IngestError> {
354        let resp = self
355            .common
356            .send_recv(&IngestMessage::IngestStatusRequest {})
357            .await?;
358
359        match resp {
360            IngestResponse::IngestStatusResponse {
361                current_timeline,
362                events_received,
363                events_written,
364                events_pending,
365            } => Ok(IngestStatus {
366                current_timeline,
367                events_received,
368                events_written,
369                events_pending,
370            }),
371            _ => Err(IngestError::ProtocolError(
372                "Invalid status response recieved",
373            )),
374        }
375    }
376}
377
378impl IngestClientCommon {
379    pub async fn timeline_metadata(
380        &mut self,
381        attrs: impl IntoIterator<Item = (InternedAttrKey, AttrVal)>,
382    ) -> Result<(), IngestError> {
383        let packed_attrs = PackedAttrKvs(attrs.into_iter().collect());
384
385        self.send(&IngestMessage::TimelineMetadata {
386            attrs: packed_attrs,
387        })
388        .await?;
389        Ok(())
390    }
391
392    pub async fn event(
393        &mut self,
394        ordering: u128,
395        attrs: impl IntoIterator<Item = (InternedAttrKey, AttrVal)>,
396    ) -> Result<(), IngestError> {
397        let packed_attrs = PackedAttrKvs(attrs.into_iter().collect());
398
399        let be_ordering = ordering.to_be_bytes();
400        let mut i = 0;
401        while i < 15 {
402            if be_ordering[i] != 0x00 {
403                break;
404            }
405            i += 1;
406        }
407        let compact_be_ordering = be_ordering[i..16].to_vec();
408
409        self.send(&IngestMessage::Event {
410            be_ordering: compact_be_ordering,
411            attrs: packed_attrs,
412        })
413        .await?;
414
415        Ok(())
416    }
417
418    // TODO make a blocking_flush as well, good for tests
419    pub async fn flush(&mut self) -> Result<(), IngestError> {
420        self.send(&IngestMessage::Flush {}).await?;
421
422        Ok(())
423    }
424}
425
426pub struct IngestStatus {
427    pub current_timeline: Option<TimelineId>,
428    pub events_received: u64,
429    pub events_written: u64,
430    pub events_pending: u64,
431}
432
433#[derive(Debug, Error)]
434pub enum IngestClientInitializationError {
435    #[error("DNS Error: No IPs")]
436    NoIps,
437
438    #[error("Socket initialization error")]
439    SocketInit(#[source] std::io::Error),
440
441    #[error("Socket connection error. Remote address: {}", remote_addr)]
442    SocketConnection {
443        #[source]
444        error: std::io::Error,
445        remote_addr: SocketAddr,
446    },
447
448    #[error("TLS Error")]
449    Tls(#[from] native_tls::Error),
450
451    #[error("Client local address parsing failed.")]
452    ClientLocalAddrParse(#[from] std::net::AddrParseError),
453
454    #[error("Error parsing endpoint")]
455    ParseIngestEndpoint(#[from] ParseIngestEndpointError),
456}
457
458#[derive(Error)]
459pub enum IngestError {
460    #[error(transparent)]
461    LoadConfigError(Box<dyn std::error::Error + Send + Sync>),
462
463    #[error("Authentication Error: {message:?}")]
464    AuthenticationError {
465        message: Option<String>,
466        client: Box<IngestClient<UnauthenticatedState>>,
467    },
468
469    #[error("Protocol Error: {0}")]
470    ProtocolError(&'static str),
471
472    #[error("Marshalling Error (Write)")]
473    CborEncode(#[from] minicbor::encode::Error<std::io::Error>),
474
475    #[error("Marshalling Error (Read)")]
476    CborDecode(#[from] minicbor::decode::Error),
477
478    #[error("Timeout")]
479    Timeout(#[from] tokio::time::error::Elapsed),
480
481    #[error("Event attr keys must begin with 'event.', and timeline attr keys must begin with 'timeline.'")]
482    AttrKeyNaming,
483
484    #[error(transparent)]
485    IngestClientInitializationError(#[from] IngestClientInitializationError),
486
487    #[error("IO")]
488    Io(#[from] std::io::Error),
489}
490
491// Manual impl so we can skip the embedded 'client'
492impl std::fmt::Debug for IngestError {
493    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494        match self {
495            Self::LoadConfigError(arg0) => f.debug_tuple("LoadConfigError").field(arg0).finish(),
496            Self::AuthenticationError { message, .. } => f
497                .debug_struct("AuthenticationError")
498                .field("message", message)
499                .finish(),
500            Self::ProtocolError(arg0) => f.debug_tuple("ProtocolError").field(arg0).finish(),
501            Self::CborEncode(arg0) => f.debug_tuple("CborEncode").field(arg0).finish(),
502            Self::CborDecode(arg0) => f.debug_tuple("CborDecode").field(arg0).finish(),
503            Self::Timeout(arg0) => f.debug_tuple("Timeout").field(arg0).finish(),
504            Self::AttrKeyNaming => write!(f, "AttrKeyNaming"),
505            Self::IngestClientInitializationError(arg0) => f
506                .debug_tuple("IngestClientInitializationError")
507                .field(arg0)
508                .finish(),
509            Self::Io(arg0) => f.debug_tuple("Io").field(arg0).finish(),
510        }
511    }
512}
513
/// Port used for `modality-ingest://` URLs that don't specify one.
pub const MODALITY_STORAGE_SERVICE_PORT_DEFAULT: u16 = 14_182;
/// Port used for `modality-ingest-tls://` URLs that don't specify one.
pub const MODALITY_STORAGE_SERVICE_TLS_PORT_DEFAULT: u16 = 14_184;
/// URL scheme selecting a plain TCP ingest connection.
pub const MODALITY_INGEST_URL_SCHEME: &'static str = "modality-ingest";
/// URL scheme selecting a TLS ingest connection.
pub const MODALITY_INGEST_TLS_URL_SCHEME: &'static str = "modality-ingest-tls";
518
519struct IngestEndpoint {
520    cert_domain: String,
521    addrs: Vec<SocketAddr>,
522    tls_mode: Option<TlsMode>,
523}
524
525impl IngestEndpoint {
526    async fn parse_and_resolve(
527        url: &Url,
528        allow_insecure_tls: bool,
529    ) -> Result<IngestEndpoint, ParseIngestEndpointError> {
530        let host = match url.host() {
531            Some(h) => h,
532            None => return Err(ParseIngestEndpointError::MissingHost),
533        };
534
535        let is_tls = match url.scheme() {
536            MODALITY_INGEST_URL_SCHEME => false,
537            MODALITY_INGEST_TLS_URL_SCHEME => true,
538            s => return Err(ParseIngestEndpointError::InvalidScheme(s.to_string())),
539        };
540        let port = match url.port() {
541            Some(p) => p,
542            _ => {
543                if is_tls {
544                    MODALITY_STORAGE_SERVICE_TLS_PORT_DEFAULT
545                } else {
546                    MODALITY_STORAGE_SERVICE_PORT_DEFAULT
547                }
548            }
549        };
550
551        let addrs = match host {
552            url::Host::Domain(domain) => tokio::net::lookup_host((domain, port)).await?.collect(),
553            url::Host::Ipv4(addr) => vec![SocketAddr::from((addr, port))],
554            url::Host::Ipv6(addr) => vec![SocketAddr::from((addr, port))],
555        };
556
557        let tls_mode = match (is_tls, allow_insecure_tls) {
558            (true, true) => Some(TlsMode::Insecure),
559            (true, false) => Some(TlsMode::Secure),
560            (false, _) => None,
561        };
562
563        Ok(IngestEndpoint {
564            cert_domain: host.to_string(),
565            addrs,
566            tls_mode,
567        })
568    }
569}
570
571#[derive(Debug, Error)]
572pub enum ParseIngestEndpointError {
573    #[error("Url most contain a host")]
574    MissingHost,
575
576    // TODO update with the real thing
577    #[error("Invalid URL scheme '{0}'. Must be one of 'modality-ingest' or 'modality-ingest-tls'")]
578    InvalidScheme(String),
579
580    #[error("IO Error")]
581    Io(#[from] std::io::Error),
582}