1use modality_api::types::{AttrKey, AttrVal, TimelineId};
2use modality_ingest_protocol::{IngestMessage, IngestResponse, InternedAttrKey, PackedAttrKvs};
3use std::{net::SocketAddr, path::PathBuf, time::Duration};
4use thiserror::Error;
5use tokio::{
6 io::{AsyncReadExt, AsyncWriteExt},
7 net::{TcpSocket, TcpStream},
8 time::timeout,
9};
10use tokio_native_tls::TlsStream;
11use url::Url;
12
13pub struct IngestClient<S> {
14 #[allow(unused)]
15 pub(crate) state: S,
16 pub(crate) common: IngestClientCommon,
17}
18
/// Typestate marker: connected to the server but not yet authenticated.
pub struct UnauthenticatedState {}

/// Typestate marker: authenticated, with no timeline open.
pub struct ReadyState {}
21pub struct BoundTimelineState {
22 pub(crate) timeline_id: TimelineId,
23}
24
25#[doc(hidden)]
27pub struct IngestClientCommon {
28 pub timeout: Duration,
29 connection: IngestConnection,
30 next_id: u32,
31}
32
33impl IngestClientCommon {
34 #[doc(hidden)]
35 pub fn new(timeout: Duration, connection: IngestConnection) -> Self {
36 IngestClientCommon {
37 timeout,
38 connection,
39 next_id: 0,
40 }
41 }
42
43 #[doc(hidden)]
45 pub async fn send_recv(&mut self, msg: &IngestMessage) -> Result<IngestResponse, IngestError> {
46 self.connection.write_msg(msg).await?;
47 timeout(self.timeout, self.connection.read_msg()).await?
48 }
49
50 #[doc(hidden)]
52 pub async fn send(&mut self, msg: &IngestMessage) -> Result<(), IngestError> {
53 self.connection.write_msg(msg).await
54 }
55
56 pub(crate) async fn declare_attr_key<K: Into<AttrKey>>(
57 &mut self,
58 key_name: K,
59 ) -> Result<InternedAttrKey, IngestError> {
60 let key_name = key_name.into();
61
62 if !(key_name.as_ref().starts_with("timeline.") || key_name.as_ref().starts_with("event."))
63 {
64 return Err(IngestError::AttrKeyNaming);
65 }
66
67 let wire_id = self.next_id;
68 self.next_id += 1;
69 let wire_id = wire_id.into();
70
71 self.send(&IngestMessage::DeclareAttrKey {
72 name: key_name.into(),
73 wire_id,
74 })
75 .await?;
76
77 Ok(wire_id)
78 }
79}
80
/// Certificate-validation policy applied to TLS ingest connections.
#[derive(Copy, Clone)]
pub enum TlsMode {
    /// Validate the server certificate normally.
    Secure,
    /// Accept invalid certificates (maps to `danger_accept_invalid_certs(true)`).
    Insecure,
}
86
87pub enum IngestConnection {
88 Tcp(TcpStream),
89 Tls(TlsStream<TcpStream>),
90}
91
92impl IngestConnection {
93 pub async fn connect(
94 endpoint: &Url,
95 allow_insecure_tls: bool,
96 ) -> Result<IngestConnection, IngestClientInitializationError> {
97 let endpoint = IngestEndpoint::parse_and_resolve(endpoint, allow_insecure_tls).await?;
98
99 let remote_addr = endpoint
101 .addrs
102 .into_iter()
103 .next()
104 .ok_or(IngestClientInitializationError::NoIps)?;
105
106 let local_addr: SocketAddr = if remote_addr.is_ipv4() {
107 "0.0.0.0:0"
108 } else {
109 "[::]:0"
110 }
111 .parse()?;
112
113 let socket = if remote_addr.is_ipv4() {
114 TcpSocket::new_v4().map_err(IngestClientInitializationError::SocketInit)?
115 } else {
116 TcpSocket::new_v6().map_err(IngestClientInitializationError::SocketInit)?
117 };
118
119 socket
120 .bind(local_addr)
121 .map_err(IngestClientInitializationError::SocketInit)?;
122 let stream = socket.connect(remote_addr).await.map_err(|error| {
123 IngestClientInitializationError::SocketConnection { error, remote_addr }
124 })?;
125
126 if let Some(tls_mode) = endpoint.tls_mode {
127 let cx = native_tls::TlsConnector::builder()
128 .danger_accept_invalid_certs(match tls_mode {
129 TlsMode::Secure => false,
130 TlsMode::Insecure => true,
131 })
132 .build()?;
133 let cx = tokio_native_tls::TlsConnector::from(cx);
134 let stream = cx.connect(&endpoint.cert_domain, stream).await?;
135 Ok(IngestConnection::Tls(stream))
136 } else {
137 Ok(IngestConnection::Tcp(stream))
138 }
139 }
140
141 async fn write_msg(&mut self, msg: &IngestMessage) -> Result<(), IngestError> {
142 let msg_buf = minicbor::to_vec(msg)?;
143 let msg_len = msg_buf.len() as u32;
144
145 match self {
146 IngestConnection::Tcp(s) => {
147 s.write_all(&msg_len.to_be_bytes())
148 .await
149 .map_err(minicbor::encode::Error::Write)?;
150 s.write_all(&msg_buf)
151 .await
152 .map_err(minicbor::encode::Error::Write)?;
153 }
154 IngestConnection::Tls(s) => {
155 s.write_all(&msg_len.to_be_bytes())
157 .await
158 .map_err(minicbor::encode::Error::Write)?;
159 s.write_all(&msg_buf)
160 .await
161 .map_err(minicbor::encode::Error::Write)?;
162 }
163 }
164
165 Ok(())
166 }
167
168 async fn read_msg(&mut self) -> Result<IngestResponse, IngestError> {
169 match self {
170 IngestConnection::Tcp(s) => {
171 let msg_len = s.read_u32().await?; let mut msg_buf = vec![0u8; msg_len as usize];
173 s.read_exact(msg_buf.as_mut_slice()).await?;
174
175 Ok(minicbor::decode::<IngestResponse>(&msg_buf)?)
176 }
177 IngestConnection::Tls(s) => {
178 let msg_len = s.read_u32().await?; let mut msg_buf = vec![0u8; msg_len as usize];
180 s.read_exact(msg_buf.as_mut_slice()).await?;
181
182 Ok(minicbor::decode::<IngestResponse>(&msg_buf)?)
183 }
184 }
185 }
186}
187
188impl IngestClient<UnauthenticatedState> {
189 pub async fn connect(
191 endpoint: &Url,
192 allow_insecure_tls: bool,
193 ) -> Result<IngestClient<UnauthenticatedState>, IngestClientInitializationError> {
194 let connection = IngestConnection::connect(endpoint, allow_insecure_tls).await?;
195 let common = IngestClientCommon::new(Duration::from_secs(1), connection);
196
197 Ok(IngestClient {
198 state: UnauthenticatedState {},
199 common,
200 })
201 }
202
203 pub async fn connect_with_timeout(
205 endpoint: &Url,
206 allow_insecure_tls: bool,
207 timeout: Duration,
208 ) -> Result<IngestClient<UnauthenticatedState>, IngestClientInitializationError> {
209 let connection = IngestConnection::connect(endpoint, allow_insecure_tls).await?;
210 let common = IngestClientCommon::new(timeout, connection);
211
212 Ok(IngestClient {
213 state: UnauthenticatedState {},
214 common,
215 })
216 }
217
218 pub async fn authenticate(
219 mut self,
220 token: Vec<u8>,
221 ) -> Result<IngestClient<ReadyState>, IngestError> {
222 let resp = self
223 .common
224 .send_recv(&IngestMessage::AuthRequest { token })
225 .await?;
226
227 match resp {
228 IngestResponse::AuthResponse { ok, message } => {
229 if ok {
230 Ok(IngestClient {
231 state: ReadyState {},
232 common: self.common,
233 })
234 } else {
235 Err(IngestError::AuthenticationError {
236 message,
237 client: Box::new(self),
238 })
239 }
240 }
241 _ => Err(IngestError::ProtocolError(
242 "Invalid response received in the 'Unauthenticated' state.",
243 )),
244 }
245 }
246}
247
248impl IngestClient<ReadyState> {
249 pub async fn connect_with_standard_config(
252 timeout: Duration,
253 manually_provided_config_path: Option<PathBuf>,
254 manually_provided_auth_token: Option<PathBuf>,
255 ) -> Result<IngestClient<ReadyState>, IngestError> {
256 let (config, auth_token) = modality_reflector_config::resolve::load_config_and_auth_token(
257 manually_provided_config_path,
258 manually_provided_auth_token,
259 )
260 .map_err(IngestError::LoadConfigError)?;
261
262 let mut endpoint = None;
263 let mut allow_insecure_tls = false;
264 if let Some(ingest) = config.ingest {
265 allow_insecure_tls = ingest.allow_insecure_tls;
266 endpoint = ingest.protocol_parent_url;
267 };
268
269 let endpoint =
270 endpoint.unwrap_or_else(|| Url::parse("modality-ingest://127.0.0.1").unwrap());
271
272 let client = IngestClient::<UnauthenticatedState>::connect_with_timeout(
273 &endpoint,
274 allow_insecure_tls,
275 timeout,
276 )
277 .await?;
278
279 client.authenticate(auth_token.into()).await
280 }
281
282 pub async fn open_timeline(
283 mut self,
284 id: TimelineId,
285 ) -> Result<IngestClient<BoundTimelineState>, IngestError> {
286 self.common
287 .send(&IngestMessage::OpenTimeline { id })
288 .await?;
289
290 Ok(IngestClient {
291 state: BoundTimelineState { timeline_id: id },
292 common: self.common,
293 })
294 }
295
296 pub async fn declare_attr_key(
297 &mut self,
298 key_name: String,
299 ) -> Result<InternedAttrKey, IngestError> {
300 self.common.declare_attr_key(key_name).await
301 }
302}
303
304impl IngestClient<BoundTimelineState> {
305 pub fn bound_timeline(&self) -> TimelineId {
306 self.state.timeline_id
307 }
308
309 pub async fn open_timeline(&mut self, id: TimelineId) -> Result<(), IngestError> {
310 self.common
311 .send(&IngestMessage::OpenTimeline { id })
312 .await?;
313 self.state.timeline_id = id;
314 Ok(())
315 }
316
317 pub fn close_timeline(self) -> IngestClient<ReadyState> {
320 IngestClient {
321 state: ReadyState {},
322 common: self.common,
323 }
324 }
325
326 pub async fn declare_attr_key(
327 &mut self,
328 key_name: String,
329 ) -> Result<InternedAttrKey, IngestError> {
330 self.common.declare_attr_key(key_name).await
331 }
332
333 pub async fn timeline_metadata(
334 &mut self,
335 attrs: impl IntoIterator<Item = (InternedAttrKey, AttrVal)>,
336 ) -> Result<(), IngestError> {
337 self.common.timeline_metadata(attrs).await
338 }
339
340 pub async fn event(
341 &mut self,
342 ordering: u128,
343 attrs: impl IntoIterator<Item = (InternedAttrKey, AttrVal)>,
344 ) -> Result<(), IngestError> {
345 self.common.event(ordering, attrs).await
346 }
347
348 pub async fn flush(&mut self) -> Result<(), IngestError> {
350 self.common.flush().await
351 }
352
353 pub async fn status(&mut self) -> Result<IngestStatus, IngestError> {
354 let resp = self
355 .common
356 .send_recv(&IngestMessage::IngestStatusRequest {})
357 .await?;
358
359 match resp {
360 IngestResponse::IngestStatusResponse {
361 current_timeline,
362 events_received,
363 events_written,
364 events_pending,
365 } => Ok(IngestStatus {
366 current_timeline,
367 events_received,
368 events_written,
369 events_pending,
370 }),
371 _ => Err(IngestError::ProtocolError(
372 "Invalid status response recieved",
373 )),
374 }
375 }
376}
377
378impl IngestClientCommon {
379 pub async fn timeline_metadata(
380 &mut self,
381 attrs: impl IntoIterator<Item = (InternedAttrKey, AttrVal)>,
382 ) -> Result<(), IngestError> {
383 let packed_attrs = PackedAttrKvs(attrs.into_iter().collect());
384
385 self.send(&IngestMessage::TimelineMetadata {
386 attrs: packed_attrs,
387 })
388 .await?;
389 Ok(())
390 }
391
392 pub async fn event(
393 &mut self,
394 ordering: u128,
395 attrs: impl IntoIterator<Item = (InternedAttrKey, AttrVal)>,
396 ) -> Result<(), IngestError> {
397 let packed_attrs = PackedAttrKvs(attrs.into_iter().collect());
398
399 let be_ordering = ordering.to_be_bytes();
400 let mut i = 0;
401 while i < 15 {
402 if be_ordering[i] != 0x00 {
403 break;
404 }
405 i += 1;
406 }
407 let compact_be_ordering = be_ordering[i..16].to_vec();
408
409 self.send(&IngestMessage::Event {
410 be_ordering: compact_be_ordering,
411 attrs: packed_attrs,
412 })
413 .await?;
414
415 Ok(())
416 }
417
418 pub async fn flush(&mut self) -> Result<(), IngestError> {
420 self.send(&IngestMessage::Flush {}).await?;
421
422 Ok(())
423 }
424}
425
426pub struct IngestStatus {
427 pub current_timeline: Option<TimelineId>,
428 pub events_received: u64,
429 pub events_written: u64,
430 pub events_pending: u64,
431}
432
433#[derive(Debug, Error)]
434pub enum IngestClientInitializationError {
435 #[error("DNS Error: No IPs")]
436 NoIps,
437
438 #[error("Socket initialization error")]
439 SocketInit(#[source] std::io::Error),
440
441 #[error("Socket connection error. Remote address: {}", remote_addr)]
442 SocketConnection {
443 #[source]
444 error: std::io::Error,
445 remote_addr: SocketAddr,
446 },
447
448 #[error("TLS Error")]
449 Tls(#[from] native_tls::Error),
450
451 #[error("Client local address parsing failed.")]
452 ClientLocalAddrParse(#[from] std::net::AddrParseError),
453
454 #[error("Error parsing endpoint")]
455 ParseIngestEndpoint(#[from] ParseIngestEndpointError),
456}
457
458#[derive(Error)]
459pub enum IngestError {
460 #[error(transparent)]
461 LoadConfigError(Box<dyn std::error::Error + Send + Sync>),
462
463 #[error("Authentication Error: {message:?}")]
464 AuthenticationError {
465 message: Option<String>,
466 client: Box<IngestClient<UnauthenticatedState>>,
467 },
468
469 #[error("Protocol Error: {0}")]
470 ProtocolError(&'static str),
471
472 #[error("Marshalling Error (Write)")]
473 CborEncode(#[from] minicbor::encode::Error<std::io::Error>),
474
475 #[error("Marshalling Error (Read)")]
476 CborDecode(#[from] minicbor::decode::Error),
477
478 #[error("Timeout")]
479 Timeout(#[from] tokio::time::error::Elapsed),
480
481 #[error("Event attr keys must begin with 'event.', and timeline attr keys must begin with 'timeline.'")]
482 AttrKeyNaming,
483
484 #[error(transparent)]
485 IngestClientInitializationError(#[from] IngestClientInitializationError),
486
487 #[error("IO")]
488 Io(#[from] std::io::Error),
489}
490
491impl std::fmt::Debug for IngestError {
493 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494 match self {
495 Self::LoadConfigError(arg0) => f.debug_tuple("LoadConfigError").field(arg0).finish(),
496 Self::AuthenticationError { message, .. } => f
497 .debug_struct("AuthenticationError")
498 .field("message", message)
499 .finish(),
500 Self::ProtocolError(arg0) => f.debug_tuple("ProtocolError").field(arg0).finish(),
501 Self::CborEncode(arg0) => f.debug_tuple("CborEncode").field(arg0).finish(),
502 Self::CborDecode(arg0) => f.debug_tuple("CborDecode").field(arg0).finish(),
503 Self::Timeout(arg0) => f.debug_tuple("Timeout").field(arg0).finish(),
504 Self::AttrKeyNaming => write!(f, "AttrKeyNaming"),
505 Self::IngestClientInitializationError(arg0) => f
506 .debug_tuple("IngestClientInitializationError")
507 .field(arg0)
508 .finish(),
509 Self::Io(arg0) => f.debug_tuple("Io").field(arg0).finish(),
510 }
511 }
512}
513
/// Port used for `modality-ingest://` URLs that specify no port.
pub const MODALITY_STORAGE_SERVICE_PORT_DEFAULT: u16 = 14182;
/// Port used for `modality-ingest-tls://` URLs that specify no port.
pub const MODALITY_STORAGE_SERVICE_TLS_PORT_DEFAULT: u16 = 14184;
/// URL scheme selecting a plain-TCP ingest connection.
pub const MODALITY_INGEST_URL_SCHEME: &'static str = "modality-ingest";
/// URL scheme selecting a TLS ingest connection.
pub const MODALITY_INGEST_TLS_URL_SCHEME: &'static str = "modality-ingest-tls";
518
519struct IngestEndpoint {
520 cert_domain: String,
521 addrs: Vec<SocketAddr>,
522 tls_mode: Option<TlsMode>,
523}
524
525impl IngestEndpoint {
526 async fn parse_and_resolve(
527 url: &Url,
528 allow_insecure_tls: bool,
529 ) -> Result<IngestEndpoint, ParseIngestEndpointError> {
530 let host = match url.host() {
531 Some(h) => h,
532 None => return Err(ParseIngestEndpointError::MissingHost),
533 };
534
535 let is_tls = match url.scheme() {
536 MODALITY_INGEST_URL_SCHEME => false,
537 MODALITY_INGEST_TLS_URL_SCHEME => true,
538 s => return Err(ParseIngestEndpointError::InvalidScheme(s.to_string())),
539 };
540 let port = match url.port() {
541 Some(p) => p,
542 _ => {
543 if is_tls {
544 MODALITY_STORAGE_SERVICE_TLS_PORT_DEFAULT
545 } else {
546 MODALITY_STORAGE_SERVICE_PORT_DEFAULT
547 }
548 }
549 };
550
551 let addrs = match host {
552 url::Host::Domain(domain) => tokio::net::lookup_host((domain, port)).await?.collect(),
553 url::Host::Ipv4(addr) => vec![SocketAddr::from((addr, port))],
554 url::Host::Ipv6(addr) => vec![SocketAddr::from((addr, port))],
555 };
556
557 let tls_mode = match (is_tls, allow_insecure_tls) {
558 (true, true) => Some(TlsMode::Insecure),
559 (true, false) => Some(TlsMode::Secure),
560 (false, _) => None,
561 };
562
563 Ok(IngestEndpoint {
564 cert_domain: host.to_string(),
565 addrs,
566 tls_mode,
567 })
568 }
569}
570
571#[derive(Debug, Error)]
572pub enum ParseIngestEndpointError {
573 #[error("Url most contain a host")]
574 MissingHost,
575
576 #[error("Invalid URL scheme '{0}'. Must be one of 'modality-ingest' or 'modality-ingest-tls'")]
578 InvalidScheme(String),
579
580 #[error("IO Error")]
581 Io(#[from] std::io::Error),
582}