Skip to main content

agent_first_pay/mode/
rest.rs

1use crate::handler::{self, App};
2use crate::store;
3use crate::types::*;
4use axum::extract::State;
5use axum::http::{HeaderMap, StatusCode};
6use axum::response::IntoResponse;
7use axum::routing::post;
8use axum::Json;
9use std::io::Write;
10use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
14pub struct RestInit {
15    pub listen: String,
16    pub api_key: Option<String>,
17    pub allow_public_listen: bool,
18    pub log: Vec<String>,
19    pub data_dir: Option<String>,
20    pub startup_argv: Vec<String>,
21    pub startup_args: serde_json::Value,
22    pub startup_requested: bool,
23}
24
25struct AppState {
26    app: Arc<App>,
27    api_key: String,
28    log: Vec<String>,
29    rate_limiter: Option<RateLimiter>,
30}
31
32/// Simple token-bucket rate limiter with concurrent request tracking.
33struct RateLimiter {
34    /// Requests per second (refill rate).
35    rps: u32,
36    /// Maximum concurrent in-flight requests.
37    max_concurrent: u32,
38    /// Current in-flight count.
39    in_flight: AtomicU32,
40    /// Available tokens (scaled by 1000 for sub-integer precision).
41    tokens_milli: AtomicU64,
42    /// Last refill timestamp in milliseconds.
43    last_refill_ms: AtomicU64,
44}
45
46impl RateLimiter {
47    fn new(config: &RateLimitConfig) -> Self {
48        let rps = config.requests_per_second;
49        Self {
50            rps,
51            max_concurrent: config.max_concurrent,
52            in_flight: AtomicU32::new(0),
53            tokens_milli: AtomicU64::new(u64::from(rps) * 1000),
54            last_refill_ms: AtomicU64::new(now_ms()),
55        }
56    }
57
58    /// Try to acquire a permit. Returns Err if rate-limited.
59    fn try_acquire(&self) -> Result<RateLimitGuard<'_>, ()> {
60        // Check concurrent limit
61        if self.max_concurrent > 0 {
62            let prev = self.in_flight.fetch_add(1, Ordering::Relaxed);
63            if prev >= self.max_concurrent {
64                self.in_flight.fetch_sub(1, Ordering::Relaxed);
65                return Err(());
66            }
67        }
68
69        // Token bucket check
70        if self.rps > 0 {
71            self.refill();
72            let cost = 1000u64;
73            loop {
74                let current = self.tokens_milli.load(Ordering::Relaxed);
75                if current < cost {
76                    if self.max_concurrent > 0 {
77                        self.in_flight.fetch_sub(1, Ordering::Relaxed);
78                    }
79                    return Err(());
80                }
81                if self
82                    .tokens_milli
83                    .compare_exchange_weak(
84                        current,
85                        current - cost,
86                        Ordering::Relaxed,
87                        Ordering::Relaxed,
88                    )
89                    .is_ok()
90                {
91                    break;
92                }
93            }
94        }
95
96        Ok(RateLimitGuard { limiter: self })
97    }
98
99    fn refill(&self) {
100        let now = now_ms();
101        let last = self.last_refill_ms.load(Ordering::Relaxed);
102        let elapsed = now.saturating_sub(last);
103        if elapsed == 0 {
104            return;
105        }
106        if self
107            .last_refill_ms
108            .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
109            .is_ok()
110        {
111            let add = elapsed * u64::from(self.rps); // milli-tokens per ms = rps
112            let max = u64::from(self.rps) * 1000;
113            let _ =
114                self.tokens_milli
115                    .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
116                        Some(current.saturating_add(add).min(max))
117                    });
118        }
119    }
120}
121
122struct RateLimitGuard<'a> {
123    limiter: &'a RateLimiter,
124}
125
126impl Drop for RateLimitGuard<'_> {
127    fn drop(&mut self) {
128        if self.limiter.max_concurrent > 0 {
129            self.limiter.in_flight.fetch_sub(1, Ordering::Relaxed);
130        }
131    }
132}
133
134fn now_ms() -> u64 {
135    std::time::SystemTime::now()
136        .duration_since(std::time::UNIX_EPOCH)
137        .map(|d| d.as_millis() as u64)
138        .unwrap_or(0)
139}
140
141pub async fn run_rest(init: RestInit) {
142    let api_key: String = match init
143        .api_key
144        .or_else(|| std::env::var("AFPAY_REST_API_KEY").ok())
145    {
146        Some(s) if !s.is_empty() => s,
147        _ => {
148            let value = agent_first_data::build_cli_error(
149                "--rest-api-key is required for REST mode",
150                Some("pass an API key for bearer authentication or set AFPAY_REST_API_KEY"),
151            );
152            let rendered =
153                agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
154            let _ = writeln!(std::io::stdout(), "{rendered}");
155            std::process::exit(1);
156        }
157    };
158
159    let resolved_dir = init
160        .data_dir
161        .unwrap_or_else(|| RuntimeConfig::default().data_dir);
162    let mut config = match RuntimeConfig::load_from_dir(&resolved_dir) {
163        Ok(c) => c,
164        Err(e) => {
165            let value = agent_first_data::build_cli_error(&e, None);
166            let rendered =
167                agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
168            let _ = writeln!(std::io::stdout(), "{rendered}");
169            std::process::exit(1);
170        }
171    };
172    if !init.log.is_empty() {
173        config.log = init.log.clone();
174    }
175
176    // Emit startup log
177    if let Some(startup) = crate::config::maybe_startup_log(
178        &config.log,
179        init.startup_requested,
180        Some(init.startup_argv),
181        Some(&config),
182        init.startup_args,
183    ) {
184        let value = serde_json::to_value(&startup).unwrap_or(serde_json::Value::Null);
185        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
186        let _ = writeln!(std::io::stdout(), "{rendered}");
187    }
188
189    let startup_errors = handler::startup_provider_validation_errors(&config).await;
190    for error_output in &startup_errors {
191        let value = serde_json::to_value(error_output).unwrap_or(serde_json::Value::Null);
192        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
193        let _ = writeln!(std::io::stdout(), "{rendered}");
194    }
195    if !startup_errors.is_empty() {
196        std::process::exit(1);
197    }
198
199    let rate_limiter = config.rate_limit.as_ref().map(RateLimiter::new);
200    let (tx, _rx) = mpsc::channel::<Output>(4096);
201    let st = store::create_storage_backend(&config);
202    let app = Arc::new(App::new(config, tx, Some(true), st));
203    let state = Arc::new(AppState {
204        app,
205        api_key,
206        log: init.log,
207        rate_limiter,
208    });
209
210    let router = axum::Router::new()
211        .route("/v1/afpay", post(handle_call))
212        .with_state(state);
213
214    let addr: std::net::SocketAddr = match init.listen.parse() {
215        Ok(a) => a,
216        Err(e) => {
217            let value = agent_first_data::build_cli_error(
218                &format!("invalid --rest-listen address: {e}"),
219                Some("expected format: host:port (e.g. 0.0.0.0:9401)"),
220            );
221            let rendered =
222                agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
223            let _ = writeln!(std::io::stdout(), "{rendered}");
224            std::process::exit(1);
225        }
226    };
227    if public_listen_requires_ack(addr) && !init.allow_public_listen {
228        let value = agent_first_data::build_cli_error(
229            "refusing to bind REST to a non-loopback address without --public-listen",
230            Some(
231                "use the default 127.0.0.1:9401, or pass --public-listen only behind TLS/firewall",
232            ),
233        );
234        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
235        let _ = writeln!(std::io::stdout(), "{rendered}");
236        std::process::exit(1);
237    }
238
239    let listener = match tokio::net::TcpListener::bind(addr).await {
240        Ok(l) => l,
241        Err(e) => {
242            let value = agent_first_data::build_cli_error(&format!("REST bind failed: {e}"), None);
243            let rendered =
244                agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
245            let _ = writeln!(std::io::stdout(), "{rendered}");
246            std::process::exit(1);
247        }
248    };
249
250    if let Err(e) = axum::serve(listener, router).await {
251        let value = agent_first_data::build_cli_error(&format!("REST server error: {e}"), None);
252        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
253        let _ = writeln!(std::io::stdout(), "{rendered}");
254        std::process::exit(1);
255    }
256}
257
/// Whether binding to `addr` needs the explicit `--public-listen` acknowledgement.
///
/// Only loopback addresses (`127.0.0.0/8`, `::1`) count as private; everything else,
/// including the unspecified `0.0.0.0` / `::`, requires the opt-in.
fn public_listen_requires_ack(addr: std::net::SocketAddr) -> bool {
    let is_local = match addr {
        std::net::SocketAddr::V4(v4) => v4.ip().is_loopback(),
        std::net::SocketAddr::V6(v6) => v6.ip().is_loopback(),
    };
    !is_local
}
261
262fn check_auth(headers: &HeaderMap, expected: &str) -> Result<(), StatusCode> {
263    // Try Authorization: Bearer <key>
264    if let Some(val) = headers.get("authorization") {
265        let val = val.to_str().map_err(|_| StatusCode::UNAUTHORIZED)?;
266        if let Some(token) = val.strip_prefix("Bearer ") {
267            if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
268                return Ok(());
269            }
270        }
271    }
272    // Try X-API-Key: <key>
273    if let Some(val) = headers.get("x-api-key") {
274        let val = val.to_str().map_err(|_| StatusCode::UNAUTHORIZED)?;
275        if constant_time_eq(val.as_bytes(), expected.as_bytes()) {
276            return Ok(());
277        }
278    }
279    Err(StatusCode::UNAUTHORIZED)
280}
281
/// Compare two byte strings without an early exit on the first differing byte,
/// so timing does not reveal how much of a guessed API key is correct.
///
/// Inputs of different length are rejected immediately (length is not treated as secret).
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // OR together every byte difference; the result is zero only for identical inputs.
        let accumulated = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (left, right)| acc | (left ^ right));
        accumulated == 0
    } else {
        false
    }
}
293
294async fn handle_call(
295    State(state): State<Arc<AppState>>,
296    headers: HeaderMap,
297    body: axum::body::Bytes,
298) -> impl IntoResponse {
299    // Rate limit check
300    let _rate_guard = if let Some(rl) = &state.rate_limiter {
301        match rl.try_acquire() {
302            Ok(guard) => Some(guard),
303            Err(()) => {
304                return (
305                    StatusCode::TOO_MANY_REQUESTS,
306                    Json(serde_json::json!({
307                        "code": "error",
308                        "error": "rate limit exceeded",
309                    })),
310                );
311            }
312        }
313    } else {
314        None
315    };
316
317    // Auth check
318    if let Err(status) = check_auth(&headers, &state.api_key) {
319        return (
320            status,
321            Json(serde_json::json!({
322                "code": "error",
323                "error": "unauthorized",
324            })),
325        );
326    }
327
328    // Parse Input from body
329    let input: Input = match serde_json::from_slice(&body) {
330        Ok(v) => v,
331        Err(e) => {
332            return (
333                StatusCode::BAD_REQUEST,
334                Json(serde_json::json!({
335                    "code": "error",
336                    "error": format!("invalid input: {e}"),
337                })),
338            );
339        }
340    };
341
342    // Block local-only operations
343    if input.is_local_only() {
344        return (
345            StatusCode::FORBIDDEN,
346            Json(serde_json::json!({
347                "code": "error",
348                "error": "local-only operation not allowed over REST",
349            })),
350        );
351    }
352
353    // Create per-request channel and App
354    let (tx, mut rx) = mpsc::channel::<Output>(256);
355    let config = state.app.config.read().await.clone();
356    let st = store::create_storage_backend(&config);
357    let app = Arc::new(App::new(config, tx, Some(true), st));
358    app.requests_total.fetch_add(1, Ordering::Relaxed);
359
360    // Dispatch
361    handler::dispatch(&app, input).await;
362
363    // Collect outputs
364    drop(app);
365    let mut outputs = Vec::new();
366    while let Some(out) = rx.recv().await {
367        // Mirror log events to daemon stdout
368        if let Output::Log { ref event, .. } = out {
369            if log_event_enabled(&state.log, event) {
370                let rendered = agent_first_data::cli_output(
371                    &serde_json::to_value(&out).unwrap_or(serde_json::Value::Null),
372                    agent_first_data::OutputFormat::Json,
373                );
374                let _ = writeln!(std::io::stdout(), "{rendered}");
375            }
376        }
377        let value = serde_json::to_value(&out).unwrap_or(serde_json::Value::Null);
378        outputs.push(value);
379    }
380
381    // Check if any output is an error
382    let has_error = outputs
383        .iter()
384        .any(|item| item.get("code").and_then(|v| v.as_str()) == Some("error"));
385
386    let status = if has_error {
387        StatusCode::UNPROCESSABLE_ENTITY
388    } else {
389        StatusCode::OK
390    };
391
392    (status, Json(serde_json::Value::Array(outputs)))
393}
394
/// Decide whether a log event should be mirrored to the daemon's stdout.
///
/// Each entry of `log_filters` is matched ASCII-case-insensitively against `event`.
/// `*` or `all` enables every event; any other entry enables events whose name starts
/// with it, so an empty-string entry matches everything. An empty filter list disables
/// mirroring.
fn log_event_enabled(log_filters: &[String], event: &str) -> bool {
    let event = event.as_bytes();
    log_filters.iter().any(|filter| {
        let filter = filter.as_str();
        filter == "*"
            || filter.eq_ignore_ascii_case("all")
            || (event.len() >= filter.len()
                && event[..filter.len()].eq_ignore_ascii_case(filter.as_bytes()))
    })
}