1use crate::handler::{self, App};
2use crate::store;
3use crate::types::*;
4use axum::extract::State;
5use axum::http::{HeaderMap, StatusCode};
6use axum::response::IntoResponse;
7use axum::routing::post;
8use axum::Json;
9use std::io::Write;
10use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
/// Startup parameters for REST mode, consumed by [`run_rest`].
pub struct RestInit {
    /// Listen address, parsed as a `std::net::SocketAddr` (`host:port`, e.g. `127.0.0.1:9401`).
    pub listen: String,
    /// API key required on every request; when `None`, `run_rest` falls back to the
    /// `AFPAY_REST_API_KEY` environment variable. An empty key is treated as missing.
    pub api_key: Option<String>,
    /// Must be `true` for `run_rest` to bind a non-loopback address.
    pub allow_public_listen: bool,
    /// Log-event filters; when non-empty they replace the `log` setting loaded from the
    /// runtime config.
    pub log: Vec<String>,
    /// Directory passed to `RuntimeConfig::load_from_dir`; defaults to
    /// `RuntimeConfig::default().data_dir`.
    pub data_dir: Option<String>,
    /// Argument vector forwarded to `crate::config::maybe_startup_log`.
    pub startup_argv: Vec<String>,
    /// Structured startup arguments forwarded to `crate::config::maybe_startup_log`.
    pub startup_args: serde_json::Value,
    /// Flag forwarded to `crate::config::maybe_startup_log`.
    pub startup_requested: bool,
}
24
/// State shared by every REST request handler.
struct AppState {
    /// App built at startup; `handle_call` reads its config to build a per-request App.
    app: Arc<App>,
    /// Key that `check_auth` compares against the request credentials.
    api_key: String,
    /// Filters passed to `log_event_enabled` for each `Output::Log` a request emits.
    log: Vec<String>,
    /// Limiter built from `config.rate_limit`; `None` when that section is absent.
    rate_limiter: Option<RateLimiter>,
}
31
32struct RateLimiter {
34 rps: u32,
36 max_concurrent: u32,
38 in_flight: AtomicU32,
40 tokens_milli: AtomicU64,
42 last_refill_ms: AtomicU64,
44}
45
46impl RateLimiter {
47 fn new(config: &RateLimitConfig) -> Self {
48 let rps = config.requests_per_second;
49 Self {
50 rps,
51 max_concurrent: config.max_concurrent,
52 in_flight: AtomicU32::new(0),
53 tokens_milli: AtomicU64::new(u64::from(rps) * 1000),
54 last_refill_ms: AtomicU64::new(now_ms()),
55 }
56 }
57
58 fn try_acquire(&self) -> Result<RateLimitGuard<'_>, ()> {
60 if self.max_concurrent > 0 {
62 let prev = self.in_flight.fetch_add(1, Ordering::Relaxed);
63 if prev >= self.max_concurrent {
64 self.in_flight.fetch_sub(1, Ordering::Relaxed);
65 return Err(());
66 }
67 }
68
69 if self.rps > 0 {
71 self.refill();
72 let cost = 1000u64;
73 loop {
74 let current = self.tokens_milli.load(Ordering::Relaxed);
75 if current < cost {
76 if self.max_concurrent > 0 {
77 self.in_flight.fetch_sub(1, Ordering::Relaxed);
78 }
79 return Err(());
80 }
81 if self
82 .tokens_milli
83 .compare_exchange_weak(
84 current,
85 current - cost,
86 Ordering::Relaxed,
87 Ordering::Relaxed,
88 )
89 .is_ok()
90 {
91 break;
92 }
93 }
94 }
95
96 Ok(RateLimitGuard { limiter: self })
97 }
98
99 fn refill(&self) {
100 let now = now_ms();
101 let last = self.last_refill_ms.load(Ordering::Relaxed);
102 let elapsed = now.saturating_sub(last);
103 if elapsed == 0 {
104 return;
105 }
106 if self
107 .last_refill_ms
108 .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
109 .is_ok()
110 {
111 let add = elapsed * u64::from(self.rps); let max = u64::from(self.rps) * 1000;
113 let _ =
114 self.tokens_milli
115 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
116 Some(current.saturating_add(add).min(max))
117 });
118 }
119 }
120}
121
122struct RateLimitGuard<'a> {
123 limiter: &'a RateLimiter,
124}
125
126impl Drop for RateLimitGuard<'_> {
127 fn drop(&mut self) {
128 if self.limiter.max_concurrent > 0 {
129 self.limiter.in_flight.fetch_sub(1, Ordering::Relaxed);
130 }
131 }
132}
133
/// Milliseconds since the Unix epoch according to the system wall clock.
///
/// Returns `0` when the clock reads earlier than the epoch. Because this is wall-clock
/// time, consecutive readings can go backwards if the system clock is adjusted.
fn now_ms() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
140
141pub async fn run_rest(init: RestInit) {
142 let api_key: String = match init
143 .api_key
144 .or_else(|| std::env::var("AFPAY_REST_API_KEY").ok())
145 {
146 Some(s) if !s.is_empty() => s,
147 _ => {
148 let value = agent_first_data::build_cli_error(
149 "--rest-api-key is required for REST mode",
150 Some("pass an API key for bearer authentication or set AFPAY_REST_API_KEY"),
151 );
152 let rendered =
153 agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
154 let _ = writeln!(std::io::stdout(), "{rendered}");
155 std::process::exit(1);
156 }
157 };
158
159 let resolved_dir = init
160 .data_dir
161 .unwrap_or_else(|| RuntimeConfig::default().data_dir);
162 let mut config = match RuntimeConfig::load_from_dir(&resolved_dir) {
163 Ok(c) => c,
164 Err(e) => {
165 let value = agent_first_data::build_cli_error(&e, None);
166 let rendered =
167 agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
168 let _ = writeln!(std::io::stdout(), "{rendered}");
169 std::process::exit(1);
170 }
171 };
172 if !init.log.is_empty() {
173 config.log = init.log.clone();
174 }
175
176 if let Some(startup) = crate::config::maybe_startup_log(
178 &config.log,
179 init.startup_requested,
180 Some(init.startup_argv),
181 Some(&config),
182 init.startup_args,
183 ) {
184 let value = serde_json::to_value(&startup).unwrap_or(serde_json::Value::Null);
185 let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
186 let _ = writeln!(std::io::stdout(), "{rendered}");
187 }
188
189 let startup_errors = handler::startup_provider_validation_errors(&config).await;
190 for error_output in &startup_errors {
191 let value = serde_json::to_value(error_output).unwrap_or(serde_json::Value::Null);
192 let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
193 let _ = writeln!(std::io::stdout(), "{rendered}");
194 }
195 if !startup_errors.is_empty() {
196 std::process::exit(1);
197 }
198
199 let rate_limiter = config.rate_limit.as_ref().map(RateLimiter::new);
200 let (tx, _rx) = mpsc::channel::<Output>(4096);
201 let st = store::create_storage_backend(&config);
202 let app = Arc::new(App::new(config, tx, Some(true), st));
203 let state = Arc::new(AppState {
204 app,
205 api_key,
206 log: init.log,
207 rate_limiter,
208 });
209
210 let router = axum::Router::new()
211 .route("/v1/afpay", post(handle_call))
212 .with_state(state);
213
214 let addr: std::net::SocketAddr = match init.listen.parse() {
215 Ok(a) => a,
216 Err(e) => {
217 let value = agent_first_data::build_cli_error(
218 &format!("invalid --rest-listen address: {e}"),
219 Some("expected format: host:port (e.g. 0.0.0.0:9401)"),
220 );
221 let rendered =
222 agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
223 let _ = writeln!(std::io::stdout(), "{rendered}");
224 std::process::exit(1);
225 }
226 };
227 if public_listen_requires_ack(addr) && !init.allow_public_listen {
228 let value = agent_first_data::build_cli_error(
229 "refusing to bind REST to a non-loopback address without --public-listen",
230 Some(
231 "use the default 127.0.0.1:9401, or pass --public-listen only behind TLS/firewall",
232 ),
233 );
234 let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
235 let _ = writeln!(std::io::stdout(), "{rendered}");
236 std::process::exit(1);
237 }
238
239 let listener = match tokio::net::TcpListener::bind(addr).await {
240 Ok(l) => l,
241 Err(e) => {
242 let value = agent_first_data::build_cli_error(&format!("REST bind failed: {e}"), None);
243 let rendered =
244 agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
245 let _ = writeln!(std::io::stdout(), "{rendered}");
246 std::process::exit(1);
247 }
248 };
249
250 if let Err(e) = axum::serve(listener, router).await {
251 let value = agent_first_data::build_cli_error(&format!("REST server error: {e}"), None);
252 let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
253 let _ = writeln!(std::io::stdout(), "{rendered}");
254 std::process::exit(1);
255 }
256}
257
/// Reports whether binding `addr` needs explicit `--public-listen` acknowledgement.
///
/// That is the case for any IP other than a loopback address, including the wildcard
/// addresses `0.0.0.0` and `::`.
fn public_listen_requires_ack(addr: std::net::SocketAddr) -> bool {
    let loopback_only = match addr {
        std::net::SocketAddr::V4(v4) => v4.ip().is_loopback(),
        std::net::SocketAddr::V6(v6) => v6.ip().is_loopback(),
    };
    !loopback_only
}
261
262fn check_auth(headers: &HeaderMap, expected: &str) -> Result<(), StatusCode> {
263 if let Some(val) = headers.get("authorization") {
265 let val = val.to_str().map_err(|_| StatusCode::UNAUTHORIZED)?;
266 if let Some(token) = val.strip_prefix("Bearer ") {
267 if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
268 return Ok(());
269 }
270 }
271 }
272 if let Some(val) = headers.get("x-api-key") {
274 let val = val.to_str().map_err(|_| StatusCode::UNAUTHORIZED)?;
275 if constant_time_eq(val.as_bytes(), expected.as_bytes()) {
276 return Ok(());
277 }
278 }
279 Err(StatusCode::UNAUTHORIZED)
280}
281
/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// Inputs of different lengths are rejected immediately, so only the length, and not
/// the position of the first differing byte, is observable through timing.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (&left, &right)| acc | (left ^ right))
            == 0
}
293
294async fn handle_call(
295 State(state): State<Arc<AppState>>,
296 headers: HeaderMap,
297 body: axum::body::Bytes,
298) -> impl IntoResponse {
299 let _rate_guard = if let Some(rl) = &state.rate_limiter {
301 match rl.try_acquire() {
302 Ok(guard) => Some(guard),
303 Err(()) => {
304 return (
305 StatusCode::TOO_MANY_REQUESTS,
306 Json(serde_json::json!({
307 "code": "error",
308 "error": "rate limit exceeded",
309 })),
310 );
311 }
312 }
313 } else {
314 None
315 };
316
317 if let Err(status) = check_auth(&headers, &state.api_key) {
319 return (
320 status,
321 Json(serde_json::json!({
322 "code": "error",
323 "error": "unauthorized",
324 })),
325 );
326 }
327
328 let input: Input = match serde_json::from_slice(&body) {
330 Ok(v) => v,
331 Err(e) => {
332 return (
333 StatusCode::BAD_REQUEST,
334 Json(serde_json::json!({
335 "code": "error",
336 "error": format!("invalid input: {e}"),
337 })),
338 );
339 }
340 };
341
342 if input.is_local_only() {
344 return (
345 StatusCode::FORBIDDEN,
346 Json(serde_json::json!({
347 "code": "error",
348 "error": "local-only operation not allowed over REST",
349 })),
350 );
351 }
352
353 let (tx, mut rx) = mpsc::channel::<Output>(256);
355 let config = state.app.config.read().await.clone();
356 let st = store::create_storage_backend(&config);
357 let app = Arc::new(App::new(config, tx, Some(true), st));
358 app.requests_total.fetch_add(1, Ordering::Relaxed);
359
360 handler::dispatch(&app, input).await;
362
363 drop(app);
365 let mut outputs = Vec::new();
366 while let Some(out) = rx.recv().await {
367 if let Output::Log { ref event, .. } = out {
369 if log_event_enabled(&state.log, event) {
370 let rendered = agent_first_data::cli_output(
371 &serde_json::to_value(&out).unwrap_or(serde_json::Value::Null),
372 agent_first_data::OutputFormat::Json,
373 );
374 let _ = writeln!(std::io::stdout(), "{rendered}");
375 }
376 }
377 let value = serde_json::to_value(&out).unwrap_or(serde_json::Value::Null);
378 outputs.push(value);
379 }
380
381 let has_error = outputs
383 .iter()
384 .any(|item| item.get("code").and_then(|v| v.as_str()) == Some("error"));
385
386 let status = if has_error {
387 StatusCode::UNPROCESSABLE_ENTITY
388 } else {
389 StatusCode::OK
390 };
391
392 (status, Json(serde_json::Value::Array(outputs)))
393}
394
/// Decides whether an `Output::Log` event should be echoed to stdout.
///
/// A filter of `*` or `all` enables every event; any other filter enables events whose
/// name starts with it. Matching ignores ASCII case on both sides, so `Wallet` and
/// `wallet` are equivalent. Previously only the event was lowercased, so a filter
/// containing an uppercase letter could never match. An empty filter list disables
/// echoing entirely.
fn log_event_enabled(log_filters: &[String], event: &str) -> bool {
    log_filters.iter().any(|filter| {
        filter == "*"
            || filter.eq_ignore_ascii_case("all")
            // `get` returns `None` when the event is shorter than the filter or the cut
            // would split a multi-byte character; neither case can be a prefix match.
            || event
                .get(..filter.len())
                .is_some_and(|prefix| prefix.eq_ignore_ascii_case(filter))
    })
}