1use std::{
103 collections::BTreeMap,
104 error, fmt, io,
105 sync::{Arc, LazyLock},
106 time::SystemTime,
107};
108
109use backoff::{ExponentialBackoffBuilder, future::retry};
110use bytes::Bytes;
111use deadpool::managed::{self, BuildError, Object, PoolError};
112use opentelemetry::{
113 InstrumentationScope, KeyValue, global,
114 metrics::{Counter, Gauge, Histogram, Meter},
115};
116use opentelemetry_semantic_conventions::SCHEMA_URL;
117use rama::{Context, Layer, Service};
118use tansu_sans_io::{ApiKey, ApiVersionsRequest, Body, Frame, Header, Request, RootMessageMeta};
119use tansu_service::{FrameBytesLayer, FrameBytesService, host_port};
120use tokio::{
121 io::{AsyncReadExt as _, AsyncWriteExt as _},
122 net::TcpStream,
123 task::JoinError,
124 time::Duration,
125};
126use tracing::{Instrument, Level, debug, span};
127use tracing_subscriber::filter::ParseError;
128use url::Url;
129
130#[derive(thiserror::Error, Clone, Debug)]
132pub enum Error {
133 DeadPoolBuild(#[from] BuildError),
134 Io(Arc<io::Error>),
135 Join(Arc<JoinError>),
136 Message(String),
137 ParseFilter(Arc<ParseError>),
138 ParseUrl(#[from] url::ParseError),
139 Pool(Arc<Box<dyn error::Error + Send + Sync>>),
140 Protocol(#[from] tansu_sans_io::Error),
141 Service(#[from] tansu_service::Error),
142 UnknownApiKey(i16),
143 UnknownHost(Url),
144}
145
146impl fmt::Display for Error {
147 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148 write!(f, "{self:?}")
149 }
150}
151
152impl From<JoinError> for Error {
153 fn from(value: JoinError) -> Self {
154 Self::Join(Arc::new(value))
155 }
156}
157
158impl<E> From<PoolError<E>> for Error
159where
160 E: error::Error + Send + Sync + 'static,
161{
162 fn from(value: PoolError<E>) -> Self {
163 Self::Pool(Arc::new(Box::new(value)))
164 }
165}
166
167impl From<io::Error> for Error {
168 fn from(value: io::Error) -> Self {
169 Self::Io(Arc::new(value))
170 }
171}
172
173impl From<ParseError> for Error {
174 fn from(value: ParseError) -> Self {
175 Self::ParseFilter(Arc::new(value))
176 }
177}
178
179pub(crate) static METER: LazyLock<Meter> = LazyLock::new(|| {
180 global::meter_with_scope(
181 InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
182 .with_version(env!("CARGO_PKG_VERSION"))
183 .with_schema_url(SCHEMA_URL)
184 .build(),
185 )
186});
187
188#[derive(Debug)]
190pub struct Connection {
191 stream: TcpStream,
192 correlation_id: i32,
193}
194
195#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
197pub struct ConnectionManager {
198 broker: Url,
199 client_id: Option<String>,
200 versions: BTreeMap<i16, i16>,
201}
202
203impl ConnectionManager {
204 pub fn builder(broker: Url) -> Builder {
206 Builder::broker(broker)
207 }
208
209 pub fn client_id(&self) -> Option<String> {
211 self.client_id.clone()
212 }
213
214 pub fn api_version(&self, api_key: i16) -> Result<i16, Error> {
216 self.versions
217 .get(&api_key)
218 .copied()
219 .ok_or(Error::UnknownApiKey(api_key))
220 }
221}
222
/// Upper bound, in milliseconds, on the time spent retrying the initial TCP
/// connect to a broker (30 seconds).
const INITIAL_CONNECTION_TIMEOUT_MILLIS: u64 = 30 * 1_000;
224
225impl managed::Manager for ConnectionManager {
226 type Type = Connection;
227 type Error = Error;
228
229 async fn create(&self) -> Result<Self::Type, Self::Error> {
230 debug!(%self.broker);
231
232 let attributes = [KeyValue::new("broker", self.broker.to_string())];
233 let start = SystemTime::now();
234
235 let addr = host_port(self.broker.clone()).await?;
236
237 let backoff = ExponentialBackoffBuilder::new()
238 .with_max_elapsed_time(Some(Duration::from_millis(
239 INITIAL_CONNECTION_TIMEOUT_MILLIS,
240 )))
241 .build();
242 retry(backoff, || async {
243 Ok(TcpStream::connect(addr)
244 .await
245 .inspect(|_| {
246 TCP_CONNECT_DURATION.record(
247 start
248 .elapsed()
249 .map_or(0, |duration| duration.as_millis() as u64),
250 &attributes,
251 )
252 })
253 .inspect_err(|err| {
254 debug!(broker = %self.broker, ?err, elapsed = start.elapsed().map_or(0, |duration| duration.as_millis() as u64));
255 TCP_CONNECT_ERRORS.add(1, &attributes);
256 })
257 .map(|stream| Connection {
258 stream,
259 correlation_id: 0,
260 })?)
261 })
262 .await
263 .map_err(Into::into)
264 }
265
266 async fn recycle(
267 &self,
268 obj: &mut Self::Type,
269 metrics: &managed::Metrics,
270 ) -> managed::RecycleResult<Self::Error> {
271 debug!(?obj, ?metrics);
272
273 Ok(())
274 }
275}
276
277pub type Pool = managed::Pool<ConnectionManager>;
279
280fn status_update(pool: &Pool) {
281 let status = pool.status();
282 POOL_AVAILABLE.record(status.available as u64, &[]);
283 POOL_CURRENT_SIZE.record(status.size as u64, &[]);
284 POOL_MAX_SIZE.record(status.max_size as u64, &[]);
285 POOL_WAITING.record(status.waiting as u64, &[]);
286}
287
288#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
290pub struct Builder {
291 broker: Url,
292 client_id: Option<String>,
293}
294
295impl Builder {
296 pub fn broker(broker: Url) -> Self {
298 Self {
299 broker,
300 client_id: None,
301 }
302 }
303
304 pub fn client_id(self, client_id: Option<String>) -> Self {
306 Self { client_id, ..self }
307 }
308
309 async fn bootstrap(&self) -> Result<BTreeMap<i16, i16>, Error> {
311 let versions = BTreeMap::from([(ApiVersionsRequest::KEY, 0)]);
314
315 let req = ApiVersionsRequest::default()
316 .client_software_name(Some(env!("CARGO_PKG_NAME").into()))
317 .client_software_version(Some(env!("CARGO_PKG_VERSION").into()));
318
319 let client = Pool::builder(ConnectionManager {
320 broker: self.broker.clone(),
321 client_id: self.client_id.clone(),
322 versions,
323 })
324 .build()
325 .map(Client::new)?;
326
327 let supported = RootMessageMeta::messages().requests();
328
329 client.call(req).await.map(|response| {
330 response
331 .api_keys
332 .unwrap_or_default()
333 .into_iter()
334 .filter_map(|api| {
335 supported.get(&api.api_key).and_then(|supported| {
336 if api.min_version >= supported.version.valid.start {
337 Some((
338 api.api_key,
339 api.max_version.min(supported.version.valid.end),
340 ))
341 } else {
342 None
343 }
344 })
345 })
346 .collect()
347 })
348 }
349
350 pub async fn build(self) -> Result<Pool, Error> {
352 self.bootstrap().await.and_then(|versions| {
353 Pool::builder(ConnectionManager {
354 broker: self.broker,
355 client_id: self.client_id,
356 versions,
357 })
358 .build()
359 .map_err(Into::into)
360 })
361 }
362}
363
364#[derive(Clone, Debug)]
366pub struct FramePoolLayer {
367 pool: Pool,
368}
369
370impl FramePoolLayer {
371 pub fn new(pool: Pool) -> Self {
372 Self { pool }
373 }
374}
375
376impl<S> Layer<S> for FramePoolLayer {
377 type Service = FramePoolService<S>;
378
379 fn layer(&self, inner: S) -> Self::Service {
380 FramePoolService {
381 pool: self.pool.clone(),
382 inner,
383 }
384 }
385}
386
387#[derive(Clone, Debug)]
389pub struct FramePoolService<S> {
390 pool: Pool,
391 inner: S,
392}
393
394impl<State, S> Service<State, Frame> for FramePoolService<S>
395where
396 S: Service<Pool, Frame, Response = Frame>,
397 State: Send + Sync + 'static,
398{
399 type Response = Frame;
400 type Error = S::Error;
401
402 async fn serve(&self, ctx: Context<State>, req: Frame) -> Result<Self::Response, Self::Error> {
403 let (ctx, _) = ctx.swap_state(self.pool.clone());
404 self.inner.serve(ctx, req).await
405 }
406}
407
408#[derive(Clone, Debug)]
410pub struct RequestPoolLayer {
411 pool: Pool,
412}
413
414impl RequestPoolLayer {
415 pub fn new(pool: Pool) -> Self {
416 Self { pool }
417 }
418}
419
420impl<S> Layer<S> for RequestPoolLayer {
421 type Service = RequestPoolService<S>;
422
423 fn layer(&self, inner: S) -> Self::Service {
424 RequestPoolService {
425 pool: self.pool.clone(),
426 inner,
427 }
428 }
429}
430
431#[derive(Clone, Debug)]
433pub struct RequestPoolService<S> {
434 pool: Pool,
435 inner: S,
436}
437
438impl<State, S, Q> Service<State, Q> for RequestPoolService<S>
439where
440 Q: Request,
441 S: Service<Pool, Q>,
442 State: Send + Sync + 'static,
443{
444 type Response = S::Response;
445 type Error = S::Error;
446
447 async fn serve(&self, ctx: Context<State>, req: Q) -> Result<Self::Response, Self::Error> {
449 let (ctx, _) = ctx.swap_state(self.pool.clone());
450 self.inner.serve(ctx, req).await
451 }
452}
453
454#[derive(Clone, Debug)]
456pub struct Client {
457 service:
458 RequestPoolService<RequestConnectionService<FrameBytesService<BytesConnectionService>>>,
459}
460
461impl Client {
462 pub fn new(pool: Pool) -> Self {
464 let service = (
465 RequestPoolLayer::new(pool),
466 RequestConnectionLayer,
467 FrameBytesLayer,
468 )
469 .into_layer(BytesConnectionService);
470
471 Self { service }
472 }
473
474 pub async fn call<Q>(&self, req: Q) -> Result<Q::Response, Error>
476 where
477 Q: Request,
478 Error: From<<<Q as Request>::Response as TryFrom<Body>>::Error>,
479 {
480 self.service.serve(Context::default(), req).await
481 }
482}
483
484#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
486pub struct FrameConnectionLayer;
487
488impl<S> Layer<S> for FrameConnectionLayer {
489 type Service = FrameConnectionService<S>;
490
491 fn layer(&self, inner: S) -> Self::Service {
492 Self::Service { inner }
493 }
494}
495
496#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
498pub struct FrameConnectionService<S> {
499 inner: S,
500}
501
502impl<S> Service<Pool, Frame> for FrameConnectionService<S>
503where
504 S: Service<Object<ConnectionManager>, Frame, Response = Frame>,
505 S::Error: From<Error> + From<PoolError<Error>> + From<tansu_sans_io::Error>,
506{
507 type Response = Frame;
508 type Error = S::Error;
509
510 async fn serve(&self, ctx: Context<Pool>, req: Frame) -> Result<Self::Response, Self::Error> {
511 debug!(?req);
512
513 let api_key = req.api_key()?;
514 let api_version = req.api_version()?;
515 let client_id = req
516 .client_id()
517 .map(|client_id| client_id.map(|client_id| client_id.to_string()))?;
518
519 let pool = ctx.state();
520 status_update(pool);
521
522 let connection = pool.get().await?;
523 let correlation_id = connection.correlation_id;
524
525 let frame = Frame {
526 size: 0,
527 header: Header::Request {
528 api_key,
529 api_version,
530 correlation_id,
531 client_id,
532 },
533 body: req.body,
534 };
535
536 let (ctx, _) = ctx.swap_state(connection);
537
538 self.inner.serve(ctx, frame).await
539 }
540}
541
542#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
544pub struct RequestConnectionLayer;
545
546impl<S> Layer<S> for RequestConnectionLayer {
547 type Service = RequestConnectionService<S>;
548
549 fn layer(&self, inner: S) -> Self::Service {
550 Self::Service { inner }
551 }
552}
553
554#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
558pub struct RequestConnectionService<S> {
559 inner: S,
560}
561
562impl<Q, S> Service<Pool, Q> for RequestConnectionService<S>
563where
564 Q: Request,
565 S: Service<Object<ConnectionManager>, Frame, Response = Frame>,
566 S::Error: From<Error>
567 + From<PoolError<Error>>
568 + From<tansu_sans_io::Error>
569 + From<<Q::Response as TryFrom<Body>>::Error>,
570{
571 type Response = Q::Response;
572 type Error = S::Error;
573
574 async fn serve(&self, ctx: Context<Pool>, req: Q) -> Result<Self::Response, Self::Error> {
575 debug!(?req);
576 let pool = ctx.state();
577 let api_key = Q::KEY;
578 let api_version = pool.manager().api_version(api_key)?;
579 let client_id = pool.manager().client_id();
580 let connection = pool.get().await?;
581 let correlation_id = connection.correlation_id;
582
583 let frame = Frame {
584 size: 0,
585 header: Header::Request {
586 api_key,
587 api_version,
588 correlation_id,
589 client_id,
590 },
591 body: req.into(),
592 };
593
594 let (ctx, _) = ctx.swap_state(connection);
595
596 let frame = self.inner.serve(ctx, frame).await?;
597
598 Q::Response::try_from(frame.body)
599 .inspect(|response| debug!(?response))
600 .map_err(Into::into)
601 }
602}
603
/// The innermost [`Service`]: writes encoded request bytes to a pooled
/// connection and reads the encoded response back.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BytesConnectionService;
607
608impl BytesConnectionService {
609 async fn write(
610 &self,
611 stream: &mut TcpStream,
612 frame: Bytes,
613 attributes: &[KeyValue],
614 ) -> Result<(), Error> {
615 debug!(frame = ?&frame[..]);
616
617 let start = SystemTime::now();
618
619 stream
620 .write_all(&frame[..])
621 .await
622 .inspect(|_| {
623 TCP_SEND_DURATION.record(
624 start
625 .elapsed()
626 .map_or(0, |duration| duration.as_millis() as u64),
627 attributes,
628 );
629
630 TCP_BYTES_SENT.add(frame.len() as u64, attributes);
631 })
632 .inspect_err(|_| {
633 TCP_SEND_ERRORS.add(1, attributes);
634 })
635 .map_err(Into::into)
636 }
637
638 async fn read(&self, stream: &mut TcpStream, attributes: &[KeyValue]) -> Result<Bytes, Error> {
639 let start = SystemTime::now();
640
641 let mut size = [0u8; 4];
642 _ = stream.read_exact(&mut size).await?;
643
644 let mut buffer: Vec<u8> = vec![0u8; frame_length(size)];
645 buffer[0..size.len()].copy_from_slice(&size[..]);
646 _ = stream
647 .read_exact(&mut buffer[4..])
648 .await
649 .inspect(|_| {
650 TCP_RECEIVE_DURATION.record(
651 start
652 .elapsed()
653 .map_or(0, |duration| duration.as_millis() as u64),
654 attributes,
655 );
656
657 TCP_BYTES_RECEIVED.add(buffer.len() as u64, attributes);
658 })
659 .inspect_err(|_| {
660 TCP_RECEIVE_ERRORS.add(1, attributes);
661 })?;
662
663 Ok(Bytes::from(buffer)).inspect(|frame| debug!(frame = ?&frame[..]))
664 }
665}
666
667impl Service<Object<ConnectionManager>, Bytes> for BytesConnectionService {
668 type Response = Bytes;
669 type Error = Error;
670
671 async fn serve(
672 &self,
673 mut ctx: Context<Object<ConnectionManager>>,
674 req: Bytes,
675 ) -> Result<Self::Response, Self::Error> {
676 let c = ctx.state_mut();
677
678 let local = c.stream.local_addr()?;
679 let peer = c.stream.peer_addr()?;
680
681 let attributes = [KeyValue::new("peer", peer.to_string())];
682
683 let span = span!(Level::DEBUG, "client", local = %local, peer = %peer);
684
685 async move {
686 self.write(&mut c.stream, req, &attributes).await?;
687
688 c.correlation_id += 1;
689
690 self.read(&mut c.stream, &attributes).await
691 }
692 .instrument(span)
693 .await
694 }
695}
696
/// Returns the total length of a frame: the four byte length prefix plus
/// the payload length that the prefix encodes.
///
/// `encoded` is the big-endian signed 32-bit payload length. A negative
/// length is treated as an empty payload, so the result is never smaller
/// than the prefix itself and the computation can neither overflow nor wrap
/// around to a tiny or enormous value.
fn frame_length(encoded: [u8; 4]) -> usize {
    let declared = i32::from_be_bytes(encoded);

    // Clamped to be non-negative first, so the cast cannot sign-extend.
    let payload = declared.max(0) as usize;

    payload.saturating_add(encoded.len())
}
700
701static TCP_CONNECT_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
702 METER
703 .u64_histogram("tcp_connect_duration")
704 .with_unit("ms")
705 .with_description("The TCP connect latencies in milliseconds")
706 .build()
707});
708
709static TCP_CONNECT_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
710 METER
711 .u64_counter("tcp_connect_errors")
712 .with_description("TCP connect errors")
713 .build()
714});
715
716static TCP_SEND_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
717 METER
718 .u64_histogram("tcp_send_duration")
719 .with_unit("ms")
720 .with_description("The TCP send latencies in milliseconds")
721 .build()
722});
723
724static TCP_SEND_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
725 METER
726 .u64_counter("tcp_send_errors")
727 .with_description("TCP send errors")
728 .build()
729});
730
731static TCP_RECEIVE_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
732 METER
733 .u64_histogram("tcp_receive_duration")
734 .with_unit("ms")
735 .with_description("The TCP receive latencies in milliseconds")
736 .build()
737});
738
739static TCP_RECEIVE_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
740 METER
741 .u64_counter("tcp_receive_errors")
742 .with_description("TCP receive errors")
743 .build()
744});
745
746static TCP_BYTES_SENT: LazyLock<Counter<u64>> = LazyLock::new(|| {
747 METER
748 .u64_counter("tcp_bytes_sent")
749 .with_description("TCP bytes sent")
750 .build()
751});
752
753static TCP_BYTES_RECEIVED: LazyLock<Counter<u64>> = LazyLock::new(|| {
754 METER
755 .u64_counter("tcp_bytes_received")
756 .with_description("TCP bytes received")
757 .build()
758});
759
760static POOL_MAX_SIZE: LazyLock<Gauge<u64>> = LazyLock::new(|| {
761 METER
762 .u64_gauge("pool_max_size")
763 .with_description("The maximum size of the pool")
764 .build()
765});
766
767static POOL_CURRENT_SIZE: LazyLock<Gauge<u64>> = LazyLock::new(|| {
768 METER
769 .u64_gauge("pool_current_size")
770 .with_description("The current size of the pool")
771 .build()
772});
773
774static POOL_AVAILABLE: LazyLock<Gauge<u64>> = LazyLock::new(|| {
775 METER
776 .u64_gauge("pool_available")
777 .with_description("The number of available objects in the pool")
778 .build()
779});
780
781static POOL_WAITING: LazyLock<Gauge<u64>> = LazyLock::new(|| {
782 METER
783 .u64_gauge("pool_waiting")
784 .with_description("The number of waiting objects in the pool")
785 .build()
786});
787
#[cfg(test)]
mod tests {
    use std::{
        fs::{self, File},
        thread,
    };

    use tansu_sans_io::{MetadataRequest, MetadataResponse};
    use tansu_service::{
        BytesFrameLayer, FrameRouteService, RequestLayer, ResponseService, TcpBytesLayer,
        TcpContextLayer, TcpListenerLayer,
    };
    use tokio::{net::TcpListener, task::JoinSet};
    use tokio_util::sync::CancellationToken;
    use tracing::subscriber::DefaultGuard;
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Installs a subscriber for the current thread that logs this crate at
    /// debug level into `../logs/<package>/<thread name>.log`.
    ///
    /// The log directory is created when missing, so the test does not
    /// depend on it having been prepared beforehand.
    ///
    /// # Errors
    ///
    /// Fails when the directive cannot be parsed, the current thread has no
    /// name, or the log directory or file cannot be created.
    fn init_tracing() -> Result<DefaultGuard, Error> {
        let log_dir = format!("../logs/{}", env!("CARGO_PKG_NAME"));
        fs::create_dir_all(&log_dir)?;

        Ok(tracing::subscriber::set_default(
            tracing_subscriber::fmt()
                .with_level(true)
                .with_line_number(true)
                .with_thread_names(false)
                .with_env_filter(
                    EnvFilter::from_default_env()
                        .add_directive(format!("{}=debug", env!("CARGO_CRATE_NAME")).parse()?),
                )
                .with_writer(
                    thread::current()
                        .name()
                        // Built lazily: the error is only allocated when needed.
                        .ok_or_else(|| Error::Message(String::from("unnamed thread")))
                        .and_then(|name| {
                            File::create(format!("{log_dir}/{name}.log")).map_err(Into::into)
                        })
                        .map(Arc::new)?,
                )
                .finish(),
        ))
    }

    /// Serves `listener` with a stack that answers every metadata request
    /// with a fixed response, until `cancellation` is triggered.
    async fn server(cancellation: CancellationToken, listener: TcpListener) -> Result<(), Error> {
        let server = (
            TcpListenerLayer::new(cancellation),
            TcpContextLayer::default(),
            TcpBytesLayer::default(),
            BytesFrameLayer::default(),
        )
            .into_layer(
                FrameRouteService::builder()
                    .with_service(RequestLayer::<MetadataRequest>::new().into_layer(
                        ResponseService::new(|_ctx: Context<()>, _req: MetadataRequest| {
                            Ok::<_, Error>(
                                MetadataResponse::default()
                                    .brokers(Some([].into()))
                                    .topics(Some([].into()))
                                    .cluster_id(Some("abc".into()))
                                    .controller_id(Some(111))
                                    .throttle_time_ms(Some(0))
                                    .cluster_authorized_operations(Some(-1)),
                            )
                        }),
                    ))
                    .and_then(|builder| builder.build())?,
            );

        server.serve(Context::default(), listener).await
    }

    /// Round trip: a pooled client sends a metadata request to the in
    /// process server and checks the fixed response it gets back.
    #[tokio::test]
    async fn tcp_client_server() -> Result<(), Error> {
        let _guard = init_tracing()?;

        let cancellation = CancellationToken::new();

        // Port 0 lets the operating system pick a free port.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let local_addr = listener.local_addr()?;

        let mut join = JoinSet::new();

        let _server = {
            let cancellation = cancellation.clone();
            join.spawn(async move { server(cancellation, listener).await })
        };

        let origin = (
            RequestPoolLayer::new(
                ConnectionManager::builder(
                    Url::parse(&format!("tcp://{local_addr}")).inspect(|url| debug!(%url))?,
                )
                .client_id(Some(env!("CARGO_PKG_NAME").into()))
                .build()
                .await
                .inspect(|pool| debug!(?pool))?,
            ),
            RequestConnectionLayer,
            FrameBytesLayer,
        )
            .into_layer(BytesConnectionService);

        let response = origin
            .serve(
                Context::default(),
                MetadataRequest::default()
                    .topics(Some([].into()))
                    .allow_auto_topic_creation(Some(false))
                    .include_cluster_authorized_operations(Some(false))
                    .include_topic_authorized_operations(Some(false)),
            )
            .await?;

        assert_eq!(Some("abc"), response.cluster_id.as_deref());
        assert_eq!(Some(111), response.controller_id);

        cancellation.cancel();

        let joined = join.join_all().await;
        debug!(?joined);

        Ok(())
    }
}