1use allora_core::adapter::{ensure_correlation, BaseAdapter, InboundAdapter};
44use allora_core::channel::{ChannelRef, QueueChannel};
45use allora_core::endpoint::{EndpointSource, InMemoryEndpoint};
47use allora_core::error::Result;
48use allora_core::{Exchange, Message, Payload};
49use async_trait::async_trait;
50use hyper::service::{make_service_fn, service_fn};
51use hyper::{Body, Request, Response, Server, Version};
52use std::collections::HashMap;
53use std::net::SocketAddr;
54use std::pin::Pin;
55use std::sync::{Arc, Mutex, Weak};
56use std::task::{Context, Poll};
57use tracing::{debug, error, info, trace};
58
/// Maximum number of seconds an `InOut` request waits for a reply on the reply channel
/// before falling back to the original inbound body.
const REPLY_TIMEOUT_SECS: u64 = 3;
/// Pause, in milliseconds, between successive polls of the reply channel.
const REPLY_POLL_INTERVAL_MILLIS: u64 = 50;

/// Message exchange pattern applied to every inbound HTTP request.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Mep {
    /// Request/response (the default): the exchange is dispatched and the HTTP response
    /// carries a body derived from the exchange or from its reply.
    #[default]
    InOut,
    /// Fire-and-forget: the exchange is sent in the background and the client receives
    /// `202 Accepted` immediately.
    InOnly202,
}
76
77#[derive(Clone, Debug)]
78pub struct HttpInboundAdapter {
79 id: String,
80 addr: SocketAddr,
81 base_path: String,
82 channel: ChannelRef,
83 mep: Mep,
84 reply_channel: Option<ChannelRef>,
85 routes: Arc<Mutex<HashMap<(String, String), Vec<Weak<InMemoryEndpoint>>>>>,
86}
87
88pub struct HttpServerHandle {
89 join: tokio::task::JoinHandle<Result<()>>,
90}
91
92impl HttpServerHandle {
93 pub async fn wait(self) -> Result<()> {
94 self.join
95 .await
96 .unwrap_or_else(|e| Err(allora_core::error::Error::other(e.to_string())))
97 }
98 pub fn abort(self) {
99 self.join.abort();
100 }
101}
102
103impl std::future::Future for HttpServerHandle {
104 type Output = Result<()>;
105 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
106 let inner = unsafe { self.map_unchecked_mut(|s| &mut s.join) };
107 match inner.poll(cx) {
108 Poll::Ready(r) => Poll::Ready(
109 r.unwrap_or_else(|e| Err(allora_core::error::Error::other(e.to_string()))),
110 ),
111 Poll::Pending => Poll::Pending,
112 }
113 }
114}
115
116impl HttpInboundAdapter {
117 pub fn id(&self) -> &str {
118 &self.id
119 }
120 pub fn addr(&self) -> SocketAddr {
121 self.addr
122 }
123 pub fn base_path(&self) -> &str {
124 &self.base_path
125 }
126 pub fn mep(&self) -> Mep {
127 self.mep
128 }
129 pub fn new(
134 host: impl Into<String>,
135 port: u16,
136 base_path: impl Into<String>,
137 channel: ChannelRef,
138 reply_channel: Option<ChannelRef>,
139 mep: Mep,
140 id: Option<String>,
141 ) -> Self {
142 let host_str = host.into();
143 let addr: SocketAddr = format!("{}:{}", host_str, port)
144 .parse()
145 .expect("invalid socket addr");
146 let base = {
147 let b = base_path.into();
148 if b.is_empty() {
149 "/".to_string()
150 } else {
151 b
152 }
153 };
154 let id_final = id.unwrap_or_else(|| format!("http-inbound:{}", addr));
155 trace!(adapter.id=%id_final, host=%host_str, port=%port, base_path=%base, mep=?mep, "constructing HttpInboundAdapter (direct)");
156 HttpInboundAdapter {
157 id: id_final,
158 addr,
159 base_path: base,
160 channel,
161 mep,
162 reply_channel,
163 routes: Arc::new(Mutex::new(HashMap::new())),
164 }
165 }
166 pub fn new_in_out(
168 host: impl Into<String>,
169 port: u16,
170 base_path: impl Into<String>,
171 channel: ChannelRef,
172 reply_channel: Option<ChannelRef>,
173 id: Option<String>,
174 ) -> Self {
175 Self::new(
176 host,
177 port,
178 base_path,
179 channel,
180 reply_channel,
181 Mep::InOut,
182 id,
183 )
184 }
185 pub fn new_in_only_202(
187 host: impl Into<String>,
188 port: u16,
189 base_path: impl Into<String>,
190 channel: ChannelRef,
191 id: Option<String>,
192 ) -> Self {
193 Self::new(host, port, base_path, channel, None, Mep::InOnly202, id)
194 }
195 }
197
198pub struct HttpInboundBuilder {
199 id: Option<String>,
200 host: String,
201 port: u16,
202 base_path: String,
203 channel: Option<ChannelRef>,
204 mep: Mep,
205 reply_channel: Option<ChannelRef>,
206 registrations: Vec<(String, String, Arc<InMemoryEndpoint>)>,
207}
208impl HttpInboundBuilder {
209 pub(crate) fn new() -> Self {
210 Self {
211 id: None,
212 host: String::new(),
213 port: 0,
214 base_path: String::new(),
215 channel: None,
216 mep: Mep::InOut,
217 reply_channel: None,
218 registrations: Vec::new(),
219 }
220 }
221 pub fn register(mut self, method: &str, path: &str, endpoint: Arc<InMemoryEndpoint>) -> Self {
223 let norm = if path.starts_with('/') {
224 path.to_string()
225 } else {
226 format!("/{}", path)
227 };
228 self.registrations
229 .push((method.to_ascii_uppercase(), norm, endpoint));
230 self
231 }
232 pub fn register_any(self, path: &str, endpoint: Arc<InMemoryEndpoint>) -> Self {
234 self.register("ANY", path, endpoint)
235 }
236 pub fn id(mut self, id: impl Into<String>) -> Self {
237 self.id = Some(id.into());
238 self
239 }
240 pub fn host(mut self, host: impl Into<String>) -> Self {
241 self.host = host.into();
242 self
243 }
244 pub fn port(mut self, port: u16) -> Self {
245 self.port = port;
246 self
247 }
248 pub fn base_path(mut self, path: impl Into<String>) -> Self {
249 self.base_path = path.into();
250 self
251 }
252 pub fn channel(mut self, ch: ChannelRef) -> Self {
253 self.channel = Some(ch);
254 self
255 }
256 pub fn reply_channel(mut self, ch: ChannelRef) -> Self {
257 self.reply_channel = Some(ch);
258 self
259 }
260
261 pub fn mep(mut self, mep: Mep) -> Self {
263 self.mep = mep;
264 self
265 }
266
267 pub fn in_only_202(self) -> Self {
269 self.mep(Mep::InOnly202)
270 }
271
272 pub fn build(self) -> HttpInboundAdapter {
273 let addr: SocketAddr = format!("{}:{}", self.host, self.port)
274 .parse()
275 .expect("invalid socket addr");
276 let id = self.id.unwrap_or_else(|| format!("http-inbound:{}", addr));
277 let base_path = if self.base_path.is_empty() { "/".to_string() } else { self.base_path };
278 let channel = self.channel.expect("channel must be set on HttpInboundBuilder before build()");
279 let effective_mep = if self.reply_channel.is_some() { Mep::InOut } else { self.mep };
280 let adapter = HttpInboundAdapter {
281 id: id.clone(),
282 addr,
283 base_path: base_path.clone(),
284 channel,
285 mep: effective_mep,
286 reply_channel: self.reply_channel.clone(),
287 routes: Arc::new(Mutex::new(HashMap::new())),
288 };
289 info!(adapter.id=%adapter.id, addr=%adapter.addr, base_path=%adapter.base_path, mep=?adapter.mep, reply_channel=adapter.reply_channel.is_some(), "HttpInboundAdapter built via builder");
290 for (method, path, ep) in self.registrations.into_iter() {
291 adapter.register_endpoint(&method, &path, Arc::downgrade(&ep));
292 }
293 adapter
294 }
295}
296
297impl BaseAdapter for HttpInboundAdapter {
298 fn id(&self) -> &str {
299 &self.id
300 }
301}
302
303#[async_trait]
304impl InboundAdapter for HttpInboundAdapter {
305 async fn run(&self) -> Result<()> {
306 self.serve().await
307 }
308}
309
/// Strips the adapter's `base` path from a request path `full`, returning the route path
/// used for endpoint lookup.
///
/// * A base of `"/"` (or one consisting only of slashes, or empty) leaves `full` unchanged.
/// * Leading and trailing slashes on `base` are ignored, so `/api`, `/api/` and `api` are
///   equivalent.
/// * The base only matches on a whole path segment. `/api` and `/api/` map to `"/"` and
///   `/api/x` maps to `"/x"`, while `/apix` is not under `/api` and is returned unchanged.
/// * Paths outside the base are returned unchanged.
fn normalize_path<'a>(base: &'a str, full: &'a str) -> &'a str {
    let base_trimmed = base.trim_matches('/');
    if base_trimmed.is_empty() {
        return full;
    }
    match full
        .strip_prefix('/')
        .and_then(|rest| rest.strip_prefix(base_trimmed))
    {
        // Exactly the base path.
        Some("") => "/",
        // The base ends on a segment boundary: keep the remainder, including its leading '/'.
        Some(rest) if rest.starts_with('/') => rest,
        // Either not under the base at all, or only a partial segment match such as `/apix`.
        _ => full,
    }
}
326
327fn http_version_str(v: Version) -> &'static str {
328 match v {
329 Version::HTTP_09 => "0.9",
330 Version::HTTP_10 => "1.0",
331 Version::HTTP_11 => "1.1",
332 Version::HTTP_2 => "2.0",
333 Version::HTTP_3 => "3.0",
334 _ => "unknown",
335 }
336}
337
338async fn adapt_request(
339 adapter_id: String,
340 channel: ChannelRef,
341 reply_channel: Option<ChannelRef>,
342 req: Request<Body>,
343 base_path: String,
344 mep: Mep,
345 routes: Arc<Mutex<HashMap<(String, String), Vec<Weak<InMemoryEndpoint>>>>>,
346) -> Result<Response<Body>> {
347 let method = req.method().clone();
348 let path_full = req.uri().path().to_string();
349 let path_norm = normalize_path(&base_path, &path_full).to_string();
350 let query = req.uri().query().unwrap_or("").to_string();
351 let version = http_version_str(req.version()).to_string();
352 let mut content_type = None::<String>;
354 let headers_clone: Vec<(String, String)> = req
355 .headers()
356 .iter()
357 .filter_map(|(name, val)| {
358 val.to_str()
359 .ok()
360 .map(|s| (name.as_str().to_ascii_lowercase(), s.to_string()))
361 })
362 .collect();
363 if let Some(ct) = headers_clone
364 .iter()
365 .find(|(k, _)| k == "content-type")
366 .map(|(_, v)| v.clone())
367 {
368 content_type = Some(ct);
369 }
370 let body_bytes = hyper::body::to_bytes(req.into_body())
372 .await
373 .map_err(|e| allora_core::error::Error::other(e.to_string()))?;
374
375 let mut msg = if let Ok(txt) = std::str::from_utf8(&body_bytes) {
377 Message::from_text(txt)
378 } else {
379 Message::new(Payload::Bytes(body_bytes.to_vec()))
380 };
381 msg.set_header("http.method", method.as_str());
382 msg.set_header("http.path", &path_norm);
383 if !query.is_empty() {
384 msg.set_header("http.query", &query);
385 }
386 msg.set_header("http.version", &version);
387 for (k, v) in headers_clone.iter() {
388 let key = format!("http.header.{}", k);
389 msg.set_header(&key, v);
390 }
391 if let Some(ct) = content_type {
392 msg.set_header("http.content_type", &ct);
393 }
394 if let Ok(txt) = std::str::from_utf8(&body_bytes) {
395 msg.set_header("http.body_text", txt);
396 }
397
398 let mut exchange = Exchange::new(msg);
400 ensure_correlation(&mut exchange);
401 debug!(adapter.id=%adapter_id, corr_id=?exchange.in_msg.header("corr_id"), "correlation ensured for inbound exchange");
402 match mep {
403 Mep::InOut => {
404 let key_exact = (method.as_str().to_ascii_uppercase(), path_norm.clone());
405 let key_any = ("ANY".to_string(), path_norm.clone());
406 let mut endpoints: Vec<Weak<InMemoryEndpoint>> = Vec::new();
407 if let Ok(map) = routes.lock() {
408 if let Some(list) = map.get(&key_exact) {
409 endpoints.extend(list.iter().cloned());
410 }
411 if let Some(list) = map.get(&key_any) {
412 endpoints.extend(list.iter().cloned());
413 }
414 }
415 if !endpoints.is_empty() {
416 debug!(adapter.id=%adapter_id, endpoints.count=endpoints.len(), path=%path_norm, "matched in-memory endpoints");
417 let mut response_body: Option<String> = None;
418 for weak_ep in endpoints.iter() {
419 if let Some(ep) = weak_ep.upgrade() {
420 if let Some(ch_ref) = ep.channel() {
421 let mut ex_clone = exchange.clone();
422 EndpointSource::Http {
423 adapter_id: adapter_id.clone(),
424 method: method.as_str().to_string(),
425 path: path_norm.clone(),
426 }
427 .apply_headers(&mut ex_clone);
428 ch_ref.send(ex_clone.clone()).await?;
429 trace!(adapter.id=%adapter_id, endpoint.channel=%ch_ref.id(), method=%method, path=%path_norm, "dispatched exchange to endpoint channel");
430 if response_body.is_none() {
431 response_body = ex_clone.in_msg.body_text().map(|s| s.to_string());
432 }
433 }
434 } else {
435 trace!(adapter.id=%adapter_id, method=%method, path=%path_norm, "skipping stale endpoint");
436 }
437 }
438 let body_final = response_body.unwrap_or_else(|| String::new());
439 return Ok(Response::new(Body::from(body_final)));
440 }
441 trace!(adapter.id=%adapter_id, channel.id=?channel.id(), mep=?mep, "no endpoints matched; sending to primary channel");
443 channel.send(exchange.clone()).await?;
444 if let Some(rc) = reply_channel {
445 if let Some(qc) = rc.as_any().downcast_ref::<QueueChannel>() {
447 use allora_core::PollableChannel;
448 let start = std::time::Instant::now();
449 while start.elapsed() < std::time::Duration::from_secs(REPLY_TIMEOUT_SECS) {
450 if let Some(ex_reply) = qc.try_receive().await {
451 let body = ex_reply
452 .out_msg
453 .as_ref()
454 .and_then(|m| m.body_text())
455 .or_else(|| ex_reply.in_msg.body_text())
456 .unwrap_or("");
457 return Ok(Response::new(Body::from(body.to_string())));
458 }
459 tokio::time::sleep(std::time::Duration::from_millis(
460 REPLY_POLL_INTERVAL_MILLIS,
461 ))
462 .await;
463 }
464 trace!(adapter.id=%adapter_id, "reply-channel timeout; returning original inbound body");
465 } else {
466 trace!(adapter.id=%adapter_id, "reply-channel present but not queue/pollable; skipping reply wait");
467 }
468 }
469 let response_body = exchange
470 .in_msg
471 .body_text()
472 .map(|s| s.to_string())
473 .unwrap_or_else(|| String::from_utf8_lossy(&body_bytes).to_string());
474 Ok(Response::new(Body::from(response_body)))
475 }
476 Mep::InOnly202 => {
477 trace!(adapter.id=%adapter_id, channel.id=?channel.id(), "IN_ONLY_202 mode: spawning background send");
478 let ch = channel.clone();
480 tokio::spawn(async move {
481 let _ = ch.send(exchange).await;
482 });
483 Ok(Response::builder()
484 .status(202)
485 .body(Body::from("ok"))
486 .unwrap())
487 }
488 }
489}
490
491impl HttpInboundAdapter {
492 pub fn register_endpoint(&self, method: &str, path: &str, ep: Weak<InMemoryEndpoint>) {
493 let key = (method.to_ascii_uppercase(), path.to_string());
494 let mut map = self.routes.lock().unwrap();
495 map.entry(key).or_insert_with(Vec::new).push(ep);
496 }
497 pub fn register_endpoint_any(&self, path: &str, ep: Weak<InMemoryEndpoint>) {
498 for m in [
499 "GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD", "ANY",
500 ] {
501 self.register_endpoint(m, path, ep.clone());
502 }
503 }
504 pub async fn serve(&self) -> Result<()> {
505 let channel = self.channel.clone();
506 let reply_channel = self.reply_channel.clone();
507 let base = self.base_path.clone();
508 let mep = self.mep;
509 let adapter_id = self.id.clone();
510 let routes_arc = self.routes.clone();
511 let make = make_service_fn(move |_conn| {
512 let channel_clone = channel.clone();
513 let base_clone = base.clone();
514 let adapter_id_clone = adapter_id.clone();
515 let routes_ref = routes_arc.clone();
516 let reply_channel_outer = reply_channel.clone();
517 async move {
518 Ok::<_, hyper::Error>(service_fn(move |req: Request<Body>| {
519 let c = channel_clone.clone();
520 let b = base_clone.clone();
521 let r = routes_ref.clone();
522 let a = adapter_id_clone.clone();
523 let rc = reply_channel_outer.clone();
524 async move {
525 match adapt_request(a, c, rc, req, b, mep, r).await {
526 Ok(resp) => Ok::<_, hyper::Error>(resp),
527 Err(e) => {
528 error!(error=%e, "request handling failed");
529 Ok(Response::builder()
530 .status(500)
531 .body(Body::from("internal error"))
532 .unwrap())
533 }
534 }
535 }
536 }))
537 }
538 });
539 info!(address=%self.addr, mep=?self.mep, "starting HTTP inbound adapter (continuous)");
540 Server::bind(&self.addr)
541 .serve(make)
542 .await
543 .map_err(|e| allora_core::error::Error::other(e.to_string()))?;
544 Ok(())
545 }
546 pub async fn run_once(self) -> Result<()> {
547 let (tx, rx) = tokio::sync::oneshot::channel::<()>();
548 let channel = self.channel.clone();
549 let reply_channel = self.reply_channel.clone();
550 let base = self.base_path.clone();
551 let mep = self.mep;
552 let adapter_id = self.id.clone();
553 let routes_arc = self.routes.clone();
554 let shutdown_flag = Arc::new(Mutex::new(Some(tx)));
555 let make = make_service_fn(move |_conn| {
556 let channel_clone = channel.clone();
557 let base_clone = base.clone();
558 let adapter_id_clone = adapter_id.clone();
559 let routes_ref = routes_arc.clone();
560 let reply_channel_outer = reply_channel.clone();
561 let shutdown_inner = shutdown_flag.clone();
562 async move {
563 Ok::<_, hyper::Error>(service_fn(move |req: Request<Body>| {
564 let c = channel_clone.clone();
565 let b = base_clone.clone();
566 let r = routes_ref.clone();
567 let a = adapter_id_clone.clone();
568 let rc = reply_channel_outer.clone();
569 let shutdown_local = shutdown_inner.clone();
570 async move {
571 let result = adapt_request(a, c, rc, req, b, mep, r).await;
572 if let Some(sender) = shutdown_local.lock().unwrap().take() {
573 let _ = sender.send(());
574 }
575 match result {
576 Ok(resp) => Ok::<_, hyper::Error>(resp),
577 Err(e) => {
578 error!(error=%e, "request handling failed (run_once)");
579 Ok(Response::builder()
580 .status(500)
581 .body(Body::from("internal error"))
582 .unwrap())
583 }
584 }
585 }
586 }))
587 }
588 });
589 info!(address=%self.addr, mep=?self.mep, "starting HTTP inbound adapter (single request)");
590 Server::bind(&self.addr)
591 .serve(make)
592 .with_graceful_shutdown(async {
593 let _ = rx.await;
594 })
595 .await
596 .map_err(|e| allora_core::error::Error::other(e.to_string()))?;
597 Ok(())
598 }
599 pub fn spawn_once(self) -> HttpServerHandle {
600 HttpServerHandle {
601 join: tokio::spawn(async move { self.run_once().await }),
602 }
603 }
604 pub fn spawn_serve(self) -> HttpServerHandle {
605 HttpServerHandle {
606 join: tokio::spawn(async move { self.serve().await }),
607 }
608 }
609}