1use opentelemetry::{InstrumentationScope, global, metrics::Meter};
16use opentelemetry_otlp::ExporterBuildError;
17use opentelemetry_sdk::error::OTelSdkError;
18use opentelemetry_semantic_conventions::SCHEMA_URL;
19use rama::{
20 Context, Layer, Service,
21 layer::{HijackLayer, MapErrLayer, MapResponseLayer},
22};
23use std::{
24 fmt, io,
25 sync::{Arc, LazyLock},
26};
27use tansu_client::{
28 BytesConnectionService, ConnectionManager, FrameConnectionLayer, FramePoolLayer,
29 RequestConnectionLayer, RequestPoolLayer,
30};
31use tansu_otel::meter_provider;
32use tansu_sans_io::{
33 ApiKey, ErrorCode, MetadataRequest, MetadataResponse, ProduceRequest,
34 metadata_response::MetadataResponseBroker,
35};
36use tansu_service::{
37 BytesFrameLayer, FrameApiKeyMatcher, FrameBytesLayer, FrameRequestLayer, TcpBytesLayer,
38 TcpContextLayer, TcpListenerLayer, host_port,
39};
40use tokio::{
41 net::TcpListener,
42 task::{JoinError, JoinSet},
43};
44use tokio_util::sync::CancellationToken;
45use tracing::debug;
46use tracing_subscriber::filter::ParseError;
47use url::Url;
48
49use crate::{
50 produce::BatchProduceLayer,
51 topic::{ResourceConfig, ResourceConfigValue, ResourceConfigValueMatcher, TopicConfigLayer},
52};
53
54mod produce;
55mod topic;
56
57#[derive(Clone, Debug, thiserror::Error)]
58pub enum Error {
59 Client(#[from] tansu_client::Error),
60 ExporterBuild(Arc<ExporterBuildError>),
61 FrameTooBig(usize),
62 Io(Arc<io::Error>),
63 Join(Arc<JoinError>),
64 Otel(#[from] tansu_otel::Error),
65 OtelSdk(Arc<OTelSdkError>),
66 ParseFilter(Arc<ParseError>),
67 Protocol(#[from] tansu_sans_io::Error),
68 ResourceLock {
69 name: String,
70 key: Option<String>,
71 value: Option<ResourceConfigValue>,
72 },
73 Service(#[from] tansu_service::Error),
74 UnknownHost(Url),
75 Message(String),
76}
77
78impl fmt::Display for Error {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 write!(f, "{self:?}")
81 }
82}
83
84impl From<JoinError> for Error {
85 fn from(value: JoinError) -> Self {
86 Self::Join(Arc::new(value))
87 }
88}
89
90impl From<OTelSdkError> for Error {
91 fn from(value: OTelSdkError) -> Self {
92 Self::OtelSdk(Arc::new(value))
93 }
94}
95
96impl From<ExporterBuildError> for Error {
97 fn from(value: ExporterBuildError) -> Self {
98 Self::ExporterBuild(Arc::new(value))
99 }
100}
101
102impl From<ParseError> for Error {
103 fn from(value: ParseError) -> Self {
104 Self::ParseFilter(Arc::new(value))
105 }
106}
107
108impl From<io::Error> for Error {
109 fn from(value: io::Error) -> Self {
110 Self::Io(Arc::new(value))
111 }
112}
113
114pub(crate) static METER: LazyLock<Meter> = LazyLock::new(|| {
115 global::meter_with_scope(
116 InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
117 .with_version(env!("CARGO_PKG_VERSION"))
118 .with_schema_url(SCHEMA_URL)
119 .build(),
120 )
121});
122
123#[derive(Clone, Debug)]
124pub struct Proxy {
125 listener: Url,
126 origin: Url,
127}
128
129impl Proxy {
130 const NODE_ID: i32 = 111;
131
132 pub fn new(listener: Url, origin: Url) -> Self {
133 Self { listener, origin }
134 }
135
136 pub async fn listen(&self) -> Result<(), Error> {
137 debug!(%self.listener);
138
139 let configuration = ResourceConfig::default();
140
141 let listener = TcpListener::bind(host_port(self.listener.clone()).await?).await?;
142
143 let token = CancellationToken::new();
144
145 let pool = ConnectionManager::builder(self.origin.clone())
146 .client_id(Some(env!("CARGO_PKG_NAME").into()))
147 .build()
148 .await
149 .inspect(|pool| debug!(?pool))?;
150
151 let request_origin = (
152 MapErrLayer::new(Error::from),
153 RequestPoolLayer::new(pool.clone()),
154 RequestConnectionLayer,
155 FrameBytesLayer,
156 )
157 .into_layer(BytesConnectionService);
158
159 let frame_origin = (
160 MapErrLayer::new(Error::from),
161 FramePoolLayer::new(pool.clone()),
162 FrameConnectionLayer,
163 FrameBytesLayer,
164 )
165 .into_layer(BytesConnectionService);
166
167 let host = String::from(self.listener.host_str().unwrap_or("localhost"));
168 let port = i32::from(self.listener.port().unwrap_or(9092));
169
170 let meta = HijackLayer::new(
171 FrameApiKeyMatcher(MetadataRequest::KEY),
172 (
173 FrameRequestLayer::<MetadataRequest>::new(),
174 MapResponseLayer::new(move |response: MetadataResponse| {
175 response.brokers(Some(vec![
176 MetadataResponseBroker::default()
177 .node_id(Self::NODE_ID)
178 .host(host)
179 .port(port)
180 .rack(None),
181 ]))
182 }),
183 )
184 .into_layer(request_origin.clone()),
185 );
186
187 let produce = HijackLayer::new(
188 FrameApiKeyMatcher(ProduceRequest::KEY),
189 (
190 FrameRequestLayer::<ProduceRequest>::new(),
191 TopicConfigLayer::new(configuration.clone(), request_origin.clone()),
192 )
193 .into_layer(
194 HijackLayer::new(
195 ResourceConfigValueMatcher::new(
196 configuration.clone(),
197 "tansu.batch",
198 "true",
199 ),
200 BatchProduceLayer::new(configuration.clone())
201 .into_layer(request_origin.clone()),
202 )
203 .into_layer(request_origin.clone()),
204 ),
205 );
206
207 let s = (
208 TcpListenerLayer::new(token),
209 TcpContextLayer::default(),
210 TcpBytesLayer::<()>::default(),
211 BytesFrameLayer,
212 meta,
213 produce,
214 )
215 .into_layer(frame_origin);
216
217 s.serve(Context::with_state(()), listener).await?;
218
219 Ok(())
220 }
221
222 pub async fn main(
223 listener_url: Url,
224 origin_url: Url,
225 otlp_endpoint_url: Option<Url>,
226 ) -> Result<ErrorCode, Error> {
227 let mut set = JoinSet::new();
228
229 let meter_provider = otlp_endpoint_url.map_or(Ok(None), |otlp_endpoint_url| {
230 meter_provider(otlp_endpoint_url, env!("CARGO_PKG_NAME")).map(Some)
231 })?;
232
233 {
234 let proxy = Proxy::new(listener_url, origin_url);
235 _ = set.spawn(async move { proxy.listen().await.unwrap() });
236 }
237
238 loop {
239 if set.join_next().await.is_none() {
240 break;
241 }
242 }
243
244 if let Some(meter_provider) = meter_provider {
245 meter_provider
246 .force_flush()
247 .inspect(|force_flush| debug!(?force_flush))?;
248
249 meter_provider
250 .shutdown()
251 .inspect(|shutdown| debug!(?shutdown))?;
252 }
253
254 Ok(ErrorCode::None)
255 }
256}
257
#[cfg(test)]
mod tests {
    use std::{
        fs::{self, File},
        path::Path,
        sync::Arc,
        thread,
    };

    use tansu_sans_io::{
        DescribeConfigsRequest, DescribeConfigsResponse, Frame, Header, ProduceResponse,
    };
    use tansu_service::{FrameService, RequestApiKeyMatcher, ResponseService};
    use tracing::subscriber::DefaultGuard;
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Install a thread-local tracing subscriber for the current test.
    ///
    /// Debug output for this crate is written to
    /// `../logs/<package>/<thread name>.log`. The directory is created if it is
    /// missing, so a fresh checkout does not fail with `NotFound`.
    fn init_tracing() -> Result<DefaultGuard, Error> {
        let current = thread::current();
        let name = current
            .name()
            .ok_or_else(|| Error::Message(String::from("unnamed thread")))?;

        let directory = Path::new("../logs").join(env!("CARGO_PKG_NAME"));
        fs::create_dir_all(&directory)?;
        let file = File::create(directory.join(format!("{name}.log")))?;

        Ok(tracing::subscriber::set_default(
            tracing_subscriber::fmt()
                .with_level(true)
                .with_line_number(true)
                .with_thread_names(false)
                .with_env_filter(
                    EnvFilter::from_default_env()
                        .add_directive(format!("{}=debug", env!("CARGO_CRATE_NAME")).parse()?),
                )
                .with_writer(Arc::new(file))
                .finish(),
        ))
    }

    #[tokio::test]
    async fn produce_hijack() -> Result<(), Error> {
        let _guard = init_tracing()?;

        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let produce = HijackLayer::new(
            FrameApiKeyMatcher(ProduceRequest::KEY),
            FrameRequestLayer::<ProduceRequest>::new().into_layer(ResponseService::new(
                |_ctx: Context<()>, _req: ProduceRequest| {
                    Ok::<_, Error>(ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS))
                },
            )),
        )
        .into_layer(FrameRequestLayer::<ProduceRequest>::new().into_layer(
            ResponseService::new(|_ctx: Context<()>, _req: ProduceRequest| {
                Ok::<_, Error>(ProduceResponse::default())
            }),
        ));

        let frame = produce
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 12,
                        correlation_id: 12321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        let response = ProduceResponse::try_from(frame.body)?;
        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    #[tokio::test]
    async fn request_api_matcher() -> Result<(), Error> {
        let _guard = init_tracing()?;

        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = HijackLayer::new(
            RequestApiKeyMatcher(ProduceRequest::KEY),
            ResponseService::new(|_, _req: ProduceRequest| {
                Ok::<_, Error>(ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS))
            }),
        )
        .into_layer(ResponseService::new(|_, _req: ProduceRequest| {
            Ok::<_, Error>(ProduceResponse::default())
        }));

        let response = service
            .serve(Context::default(), ProduceRequest::default())
            .await?;

        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    #[tokio::test]
    async fn frame_topic_config() -> Result<(), Error> {
        let _guard = init_tracing()?;

        let configuration = ResourceConfig::default();
        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = HijackLayer::new(
            FrameApiKeyMatcher(ProduceRequest::KEY),
            (
                FrameRequestLayer::<ProduceRequest>::new(),
                TopicConfigLayer::new(
                    configuration.clone(),
                    ResponseService::new(|_: Context<()>, _req: DescribeConfigsRequest| {
                        Ok::<_, Error>(DescribeConfigsResponse::default())
                    }),
                ),
            )
                .into_layer(ResponseService::new(
                    |_: Context<()>, _req: ProduceRequest| {
                        Ok::<_, Error>(
                            ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS),
                        )
                    },
                )),
        )
        .into_layer(FrameService::new(|_: Context<()>, _req: Frame| {
            Ok::<_, Error>(Frame {
                size: 0,
                header: Header::Response {
                    correlation_id: 12321,
                },
                body: MetadataResponse::default().into(),
            })
        }));

        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        // The hijacked (not the fallback) service must have answered: check the
        // throttle time it set, not just the response type.
        let response = ProduceResponse::try_from(response.body)?;
        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    #[tokio::test]
    async fn response_topic_config() -> Result<(), Error> {
        let configuration = ResourceConfig::default();
        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = TopicConfigLayer::new(
            configuration,
            ResponseService::new(|_: Context<()>, _req: DescribeConfigsRequest| {
                Ok::<_, Error>(DescribeConfigsResponse::default())
            }),
        )
        .layer(ResponseService::new(
            |_: Context<()>, _req: ProduceRequest| {
                Ok::<_, Error>(ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS))
            },
        ));

        let response = service
            .serve(Context::default(), ProduceRequest::default())
            .await?;

        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    #[tokio::test]
    async fn frame_api_matcher() -> Result<(), Error> {
        let service = HijackLayer::new(
            FrameApiKeyMatcher(ProduceRequest::KEY),
            FrameRequestLayer::<ProduceRequest>::new().into_layer(ResponseService::new(
                |_: Context<()>, _req: ProduceRequest| Ok::<_, Error>(ProduceResponse::default()),
            )),
        )
        .into_layer(FrameService::new(|_: Context<()>, _req: Frame| {
            Ok::<_, Error>(Frame {
                size: 0,
                header: Header::Response {
                    correlation_id: 12321,
                },
                body: MetadataResponse::default().into(),
            })
        }));

        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        assert!(ProduceResponse::try_from(response.body).is_ok());

        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: MetadataRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: MetadataRequest::default().into(),
                },
            )
            .await?;

        assert!(MetadataResponse::try_from(response.body).is_ok());

        Ok(())
    }
}