asterisk_rs_agi/
server.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use tokio::io::BufReader;
5use tokio::net::TcpListener;
6use tokio::sync::{watch, Semaphore};
7
8use crate::channel::AgiChannel;
9use crate::error::{AgiError, Result};
10use crate::handler::AgiHandler;
11use crate::request::AgiRequest;
12
13#[derive(Clone)]
15pub struct ShutdownHandle {
16 tx: watch::Sender<bool>,
17}
18
19impl ShutdownHandle {
20 pub fn shutdown(&self) {
22 let _ = self.tx.send(true);
23 }
24}
25
26pub struct AgiServer<H: AgiHandler> {
28 listener: TcpListener,
29 handler: Arc<H>,
30 max_connections: Option<usize>,
31 shutdown_rx: watch::Receiver<bool>,
32}
33
/// Configures and creates an [`AgiServer`]; start from [`AgiServer::builder`].
#[must_use]
pub struct AgiServerBuilder<H> {
    /// Address handed to `TcpListener::bind`; defaults to `0.0.0.0:4573`.
    bind_addr: String,
    /// Per-connection handler; `build` fails while this is unset.
    handler: Option<H>,
    /// Optional cap on the number of concurrently served connections.
    max_connections: Option<usize>,
}

42impl<H: AgiHandler> AgiServer<H> {
43 pub fn builder() -> AgiServerBuilder<H> {
45 AgiServerBuilder {
46 bind_addr: "0.0.0.0:4573".to_owned(),
47 handler: None,
48 max_connections: None,
49 }
50 }
51
52 pub async fn run(mut self) -> Result<()> {
56 let semaphore = self.max_connections.map(|n| Arc::new(Semaphore::new(n)));
57
58 loop {
59 tokio::select! {
60 result = self.listener.accept() => {
61 let (stream, peer) = match result {
62 Ok(conn) => conn,
63 Err(err) => {
64 tracing::warn!(%err, "failed to accept connection");
65 tokio::time::sleep(Duration::from_millis(100)).await;
67 continue;
68 }
69 };
70
71 tracing::debug!(%peer, "new AGI connection");
72
73 let handler = Arc::clone(&self.handler);
74
75 let permit = if let Some(sem) = &semaphore {
78 let acquire = sem.clone().acquire_owned();
79 tokio::select! {
80 result = acquire => match result {
81 Ok(p) => Some(p),
82 Err(_) => {
83 tracing::error!("connection semaphore closed unexpectedly");
85 return Err(AgiError::Io(std::io::Error::other(
86 "connection semaphore closed",
87 )));
88 }
89 },
90 _ = self.shutdown_rx.changed() => {
91 tracing::info!("AGI server shutting down");
92 return Ok(());
93 }
94 }
95 } else {
96 None
97 };
98
99 tokio::spawn(async move {
100 let _permit = permit;
102
103 if let Err(err) = handle_connection(handler, stream).await {
104 tracing::warn!(%peer, %err, "AGI session error");
105 }
106 });
107 }
108 result = self.shutdown_rx.changed() => {
109 if result.is_err() || *self.shutdown_rx.borrow() {
111 tracing::info!("AGI server shutting down");
112 return Ok(());
113 }
114 }
115 }
116 }
117 }
118}
119
120async fn handle_connection<H: AgiHandler>(
122 handler: Arc<H>,
123 stream: tokio::net::TcpStream,
124) -> Result<()> {
125 let (read_half, write_half) = stream.into_split();
126 let mut reader = BufReader::new(read_half);
127
128 let request = match tokio::time::timeout(
130 Duration::from_secs(30),
131 AgiRequest::parse_from_reader(&mut reader),
132 )
133 .await
134 {
135 Ok(result) => result?,
136 Err(_elapsed) => {
137 tracing::warn!("AGI prelude read timed out after 30s");
138 return Ok(());
139 }
140 };
141
142 let channel = AgiChannel::new(reader, write_half);
143 handler.handle(request, channel).await
144}
145
146impl<H: AgiHandler> AgiServerBuilder<H> {
147 pub fn bind(mut self, addr: impl Into<String>) -> Self {
149 self.bind_addr = addr.into();
150 self
151 }
152
153 pub fn handler(mut self, handler: H) -> Self {
155 self.handler = Some(handler);
156 self
157 }
158
159 pub fn max_connections(mut self, n: usize) -> Self {
161 self.max_connections = Some(n);
162 self
163 }
164
165 pub async fn build(self) -> Result<(AgiServer<H>, ShutdownHandle)> {
169 let handler = self.handler.ok_or_else(|| {
170 AgiError::Io(std::io::Error::new(
171 std::io::ErrorKind::InvalidInput,
172 "handler is required",
173 ))
174 })?;
175
176 let listener = TcpListener::bind(&self.bind_addr).await?;
177 let (shutdown_tx, shutdown_rx) = watch::channel(false);
178
179 tracing::info!(addr = %self.bind_addr, "FastAGI server bound");
180
181 let server = AgiServer {
182 listener,
183 handler: Arc::new(handler),
184 max_connections: self.max_connections,
185 shutdown_rx,
186 };
187
188 let handle = ShutdownHandle { tx: shutdown_tx };
189
190 Ok((server, handle))
191 }
192}