Skip to main content

agent_first_pay/mode/rpc/
mod.rs

1pub mod crypto;
2
3use self::crypto::Cipher;
4use crate::handler::{self, App};
5use crate::types::*;
6use std::collections::{HashSet, VecDeque};
7use std::io::Write;
8use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
9use std::sync::{Arc, Mutex};
10use tokio::sync::mpsc;
11use tonic::Code;
12use tonic::{Request, Response, Status};
13
14pub struct RpcInit {
15    pub listen: String,
16    pub rpc_secret: Option<String>,
17    pub allow_public_listen: bool,
18    pub log: Vec<String>,
19    pub data_dir: Option<String>,
20    pub startup_argv: Vec<String>,
21    pub startup_args: serde_json::Value,
22    pub startup_requested: bool,
23}
24
25pub mod proto {
26    tonic::include_proto!("afpay");
27}
28
29use proto::af_pay_server::{AfPay, AfPayServer};
30use proto::{EncryptedRequest, EncryptedResponse};
31
32struct AfPayService {
33    cipher: Cipher,
34    config: RuntimeConfig,
35    rate_limiter: Option<RpcRateLimiter>,
36    replay_cache: Mutex<ReplayCache>,
37}
38
/// Bounded FIFO memory of request nonces.
///
/// `seen` answers membership queries; `order` remembers insertion order so the
/// oldest nonces can be forgotten once more than `max_entries` are stored.
struct ReplayCache {
    seen: HashSet<Vec<u8>>,
    order: VecDeque<Vec<u8>>,
    max_entries: usize,
}

impl ReplayCache {
    /// Creates an empty cache that retains at most `max_entries` nonces.
    fn new(max_entries: usize) -> Self {
        Self {
            seen: HashSet::new(),
            order: VecDeque::new(),
            max_entries,
        }
    }

    /// Records `nonce` and reports whether it was new.
    ///
    /// Returns `false` (and stores nothing) when the nonce is still
    /// remembered; returns `true` after storing it and evicting the oldest
    /// entries beyond the capacity.
    fn insert_unique(&mut self, nonce: &[u8]) -> bool {
        if self.seen.contains(nonce) {
            return false;
        }

        let remembered = nonce.to_vec();
        self.seen.insert(remembered.clone());
        self.order.push_back(remembered);

        // Forget the oldest nonces until we are back within the capacity.
        let surplus = self.order.len().saturating_sub(self.max_entries);
        for stale in self.order.drain(..surplus) {
            self.seen.remove(&stale);
        }
        true
    }
}
68
69/// Simple token-bucket rate limiter for RPC.
70struct RpcRateLimiter {
71    rps: u32,
72    max_concurrent: u32,
73    in_flight: AtomicU32,
74    tokens_milli: AtomicU64,
75    last_refill_ms: AtomicU64,
76}
77
78impl RpcRateLimiter {
79    fn new(config: &RateLimitConfig) -> Self {
80        let rps = config.requests_per_second;
81        Self {
82            rps,
83            max_concurrent: config.max_concurrent,
84            in_flight: AtomicU32::new(0),
85            tokens_milli: AtomicU64::new(u64::from(rps) * 1000),
86            last_refill_ms: AtomicU64::new(rpc_now_ms()),
87        }
88    }
89
90    fn try_acquire(&self) -> Result<RpcRateLimitGuard<'_>, ()> {
91        if self.max_concurrent > 0 {
92            let prev = self.in_flight.fetch_add(1, Ordering::Relaxed);
93            if prev >= self.max_concurrent {
94                self.in_flight.fetch_sub(1, Ordering::Relaxed);
95                return Err(());
96            }
97        }
98        if self.rps > 0 {
99            self.refill();
100            let cost = 1000u64;
101            loop {
102                let current = self.tokens_milli.load(Ordering::Relaxed);
103                if current < cost {
104                    if self.max_concurrent > 0 {
105                        self.in_flight.fetch_sub(1, Ordering::Relaxed);
106                    }
107                    return Err(());
108                }
109                if self
110                    .tokens_milli
111                    .compare_exchange_weak(
112                        current,
113                        current - cost,
114                        Ordering::Relaxed,
115                        Ordering::Relaxed,
116                    )
117                    .is_ok()
118                {
119                    break;
120                }
121            }
122        }
123        Ok(RpcRateLimitGuard { limiter: self })
124    }
125
126    fn refill(&self) {
127        let now = rpc_now_ms();
128        let last = self.last_refill_ms.load(Ordering::Relaxed);
129        let elapsed = now.saturating_sub(last);
130        if elapsed == 0 {
131            return;
132        }
133        if self
134            .last_refill_ms
135            .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
136            .is_ok()
137        {
138            let add = elapsed * u64::from(self.rps);
139            let max = u64::from(self.rps) * 1000;
140            let _ = self
141                .tokens_milli
142                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |c| {
143                    Some(c.saturating_add(add).min(max))
144                });
145        }
146    }
147}
148
149struct RpcRateLimitGuard<'a> {
150    limiter: &'a RpcRateLimiter,
151}
152
153impl Drop for RpcRateLimitGuard<'_> {
154    fn drop(&mut self) {
155        if self.limiter.max_concurrent > 0 {
156            self.limiter.in_flight.fetch_sub(1, Ordering::Relaxed);
157        }
158    }
159}
160
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` when the system clock reports a time before the epoch.
fn rpc_now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
167
168#[tonic::async_trait]
169impl AfPay for AfPayService {
170    async fn call(
171        &self,
172        request: Request<EncryptedRequest>,
173    ) -> Result<Response<EncryptedResponse>, Status> {
174        let req = request.into_inner();
175
176        // Rate limit check
177        let _rate_guard = if let Some(rl) = &self.rate_limiter {
178            match rl.try_acquire() {
179                Ok(guard) => Some(guard),
180                Err(()) => {
181                    return Err(Status::resource_exhausted("rate limit exceeded"));
182                }
183            }
184        } else {
185            None
186        };
187
188        match self.replay_cache.lock() {
189            Ok(mut cache) => {
190                if !cache.insert_unique(&req.nonce) {
191                    let status = Status::unauthenticated("replayed request nonce");
192                    emit_rpc_response_log(&self.config, None, &[], Some(&status));
193                    return Err(status);
194                }
195            }
196            Err(_) => {
197                let status = Status::internal("replay cache poisoned");
198                emit_rpc_response_log(&self.config, None, &[], Some(&status));
199                return Err(status);
200            }
201        }
202
203        // Decrypt request
204        let plaintext = match self.cipher.decrypt(&req.nonce, &req.ciphertext) {
205            Ok(plaintext) => plaintext,
206            Err(_) => {
207                emit_rpc_request_log(
208                    &self.config,
209                    None,
210                    serde_json::json!({
211                        "input": serde_json::Value::Null,
212                        "decode_error": "decryption failed",
213                    }),
214                );
215                let status = Status::unauthenticated("decryption failed");
216                emit_rpc_response_log(&self.config, None, &[], Some(&status));
217                return Err(status);
218            }
219        };
220
221        let mut raw_input_value = serde_json::from_slice::<serde_json::Value>(&plaintext)
222            .unwrap_or(serde_json::Value::Null);
223        if let Some(object) = raw_input_value.as_object_mut() {
224            object.remove("id");
225        }
226
227        // Parse Input
228        let input: Input = match serde_json::from_slice(&plaintext) {
229            Ok(input) => input,
230            Err(e) => {
231                emit_rpc_request_log(
232                    &self.config,
233                    None,
234                    serde_json::json!({
235                        "input": raw_input_value,
236                        "decode_error": format!("invalid input: {e}"),
237                    }),
238                );
239                let status = Status::invalid_argument(format!("invalid input: {e}"));
240                emit_rpc_response_log(&self.config, None, &[], Some(&status));
241                return Err(status);
242            }
243        };
244        let request_id = input_request_id(&input).map(|s| s.to_string());
245        emit_rpc_request_log(
246            &self.config,
247            request_id.clone(),
248            serde_json::json!({
249                "input": raw_input_value,
250            }),
251        );
252
253        // Block local-only operations over RPC
254        if input.is_local_only() {
255            let status = Status::permission_denied("local-only operation");
256            emit_rpc_response_log(&self.config, request_id, &[], Some(&status));
257            return Err(status);
258        }
259
260        // Create per-request channel and App
261        let (tx, mut rx) = mpsc::channel::<Output>(256);
262        let store = crate::store::create_storage_backend(&self.config);
263        let app = Arc::new(App::new(self.config.clone(), tx, Some(true), store));
264        app.requests_total.fetch_add(1, Ordering::Relaxed);
265
266        // Dispatch
267        handler::dispatch(&app, input).await;
268
269        // Drop app to close the sender side, then collect all outputs
270        drop(app);
271        let mut outputs = Vec::new();
272        while let Some(out) = rx.recv().await {
273            // Mirror server-side log events to rpc daemon stdout so operators can
274            // observe request flow in long-running rpc mode.
275            if let Output::Log { .. } = &out {
276                let rendered = agent_first_data::cli_output(
277                    &serde_json::to_value(&out).unwrap_or(serde_json::Value::Null),
278                    agent_first_data::OutputFormat::Json,
279                );
280                let _ = writeln!(std::io::stdout(), "{rendered}");
281            }
282            let value = serde_json::to_value(&out).unwrap_or(serde_json::Value::Null);
283            outputs.push(value);
284        }
285
286        // Serialize outputs as JSON array
287        let response_json = match serde_json::to_vec(&outputs) {
288            Ok(response_json) => response_json,
289            Err(e) => {
290                let status = Status::internal(format!("serialize: {e}"));
291                emit_rpc_response_log(&self.config, request_id, &outputs, Some(&status));
292                return Err(status);
293            }
294        };
295
296        // Encrypt response
297        let (nonce, ciphertext) = match self.cipher.encrypt(&response_json) {
298            Ok(payload) => payload,
299            Err(e) => {
300                let status = Status::internal(format!("encrypt: {e}"));
301                emit_rpc_response_log(&self.config, request_id, &outputs, Some(&status));
302                return Err(status);
303            }
304        };
305
306        emit_rpc_response_log(&self.config, request_id, &outputs, None);
307
308        Ok(Response::new(EncryptedResponse { nonce, ciphertext }))
309    }
310}
311
312pub async fn run_rpc(init: RpcInit) {
313    let secret: String = match init
314        .rpc_secret
315        .or_else(|| std::env::var("AFPAY_RPC_SECRET").ok())
316    {
317        Some(s) if !s.is_empty() => s,
318        _ => {
319            let value = agent_first_data::build_cli_error(
320                "--rpc-secret is required for RPC mode",
321                Some("pass a shared secret for client authentication or set AFPAY_RPC_SECRET"),
322            );
323            let rendered =
324                agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
325            let _ = writeln!(std::io::stdout(), "{rendered}");
326            std::process::exit(1);
327        }
328    };
329    if let Err(e) = Cipher::validate_secret(&secret) {
330        let value = agent_first_data::build_cli_error(&e, Some("use a random 32+ byte secret"));
331        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
332        let _ = writeln!(std::io::stdout(), "{rendered}");
333        std::process::exit(1);
334    }
335
336    let cipher = Cipher::from_secret(&secret);
337
338    let resolved_dir = init
339        .data_dir
340        .unwrap_or_else(|| RuntimeConfig::default().data_dir);
341    let mut config = match RuntimeConfig::load_from_dir(&resolved_dir) {
342        Ok(c) => c,
343        Err(e) => {
344            let value = agent_first_data::build_cli_error(&e, None);
345            let rendered =
346                agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
347            let _ = writeln!(std::io::stdout(), "{rendered}");
348            std::process::exit(1);
349        }
350    };
351    if !init.log.is_empty() {
352        config.log = init.log;
353    }
354
355    // Emit startup log
356    if let Some(startup) = crate::config::maybe_startup_log(
357        &config.log,
358        init.startup_requested,
359        Some(init.startup_argv),
360        Some(&config),
361        init.startup_args,
362    ) {
363        let value = serde_json::to_value(&startup).unwrap_or(serde_json::Value::Null);
364        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
365        let _ = writeln!(std::io::stdout(), "{rendered}");
366    }
367
368    let startup_errors = crate::handler::startup_provider_validation_errors(&config).await;
369    for error_output in &startup_errors {
370        let value = serde_json::to_value(error_output).unwrap_or(serde_json::Value::Null);
371        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
372        let _ = writeln!(std::io::stdout(), "{rendered}");
373    }
374    if !startup_errors.is_empty() {
375        std::process::exit(1);
376    }
377
378    let rate_limiter = config.rate_limit.as_ref().map(RpcRateLimiter::new);
379    let service = AfPayService {
380        cipher,
381        config,
382        rate_limiter,
383        replay_cache: Mutex::new(ReplayCache::new(8192)),
384    };
385
386    let addr = match init.listen.parse() {
387        Ok(a) => a,
388        Err(e) => {
389            let value = agent_first_data::build_cli_error(
390                &format!("invalid --rpc-listen address: {e}"),
391                Some("expected format: host:port (e.g. 127.0.0.1:9100)"),
392            );
393            let rendered =
394                agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
395            let _ = writeln!(std::io::stdout(), "{rendered}");
396            std::process::exit(1);
397        }
398    };
399    if public_listen_requires_ack(addr) && !init.allow_public_listen {
400        let value = agent_first_data::build_cli_error(
401            "refusing to bind RPC to a non-loopback address without --public-listen",
402            Some(
403                "use the default 127.0.0.1:9400, or pass --public-listen only behind TLS/firewall",
404            ),
405        );
406        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
407        let _ = writeln!(std::io::stdout(), "{rendered}");
408        std::process::exit(1);
409    }
410
411    let server = tonic::transport::Server::builder()
412        .add_service(AfPayServer::new(service))
413        .serve(addr);
414
415    if let Err(e) = server.await {
416        let value = agent_first_data::build_cli_error(&format!("RPC server error: {e}"), None);
417        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
418        let _ = writeln!(std::io::stdout(), "{rendered}");
419        std::process::exit(1);
420    }
421}
422
/// Reports whether binding `addr` needs explicit opt-in, i.e. whether its IP
/// is anything other than a loopback address.
fn public_listen_requires_ack(addr: std::net::SocketAddr) -> bool {
    let is_local = match addr.ip() {
        std::net::IpAddr::V4(v4) => v4.is_loopback(),
        std::net::IpAddr::V6(v6) => v6.is_loopback(),
    };
    !is_local
}
426
/// Decides whether `event` passes the configured log filters.
///
/// A filter matches when it is `"*"` or `"all"`, or when the ASCII-lowercased
/// event name starts with the filter text (filters are compared as given).
/// An empty filter list enables nothing.
fn log_event_enabled(log_filters: &[String], event: &str) -> bool {
    let normalized = event.to_ascii_lowercase();
    for filter in log_filters {
        match filter.as_str() {
            "*" | "all" => return true,
            prefix if normalized.starts_with(prefix) => return true,
            _ => {}
        }
    }
    false
}
436
437fn emit_rpc_request_log(
438    config: &RuntimeConfig,
439    request_id: Option<String>,
440    args: serde_json::Value,
441) {
442    emit_rpc_log(config, "rpc_request", request_id, args);
443}
444
445fn emit_rpc_response_log(
446    config: &RuntimeConfig,
447    request_id: Option<String>,
448    outputs: &[serde_json::Value],
449    status: Option<&Status>,
450) {
451    let has_output_error = outputs
452        .iter()
453        .any(|item| item.get("code").and_then(|v| v.as_str()) == Some("error"));
454    let mut args = serde_json::json!({
455        "output_count": outputs.len(),
456        "has_error": has_output_error || status.is_some(),
457        "outputs": outputs,
458    });
459    if let Some(status) = status {
460        if let Some(object) = args.as_object_mut() {
461            object.insert(
462                "grpc_error".to_string(),
463                serde_json::json!({
464                    "code": grpc_code_name(status.code()),
465                    "message": status.message(),
466                }),
467            );
468        }
469    }
470    emit_rpc_log(config, "rpc_response", request_id, args);
471}
472
473fn emit_rpc_log(
474    config: &RuntimeConfig,
475    event: &str,
476    request_id: Option<String>,
477    args: serde_json::Value,
478) {
479    if !log_event_enabled(&config.log, event) {
480        return;
481    }
482    let log = Output::Log {
483        event: event.to_string(),
484        request_id,
485        version: None,
486        argv: None,
487        config: None,
488        args: Some(args),
489        env: None,
490        trace: Trace::from_duration(0),
491    };
492    let rendered = agent_first_data::cli_output(
493        &serde_json::to_value(&log).unwrap_or(serde_json::Value::Null),
494        agent_first_data::OutputFormat::Json,
495    );
496    let _ = writeln!(std::io::stdout(), "{rendered}");
497}
498
499fn grpc_code_name(code: Code) -> &'static str {
500    match code {
501        Code::Ok => "ok",
502        Code::Cancelled => "cancelled",
503        Code::Unknown => "unknown",
504        Code::InvalidArgument => "invalid_argument",
505        Code::DeadlineExceeded => "deadline_exceeded",
506        Code::NotFound => "not_found",
507        Code::AlreadyExists => "already_exists",
508        Code::PermissionDenied => "permission_denied",
509        Code::ResourceExhausted => "resource_exhausted",
510        Code::FailedPrecondition => "failed_precondition",
511        Code::Aborted => "aborted",
512        Code::OutOfRange => "out_of_range",
513        Code::Unimplemented => "unimplemented",
514        Code::Internal => "internal",
515        Code::Unavailable => "unavailable",
516        Code::DataLoss => "data_loss",
517        Code::Unauthenticated => "unauthenticated",
518    }
519}
520
521fn input_request_id(input: &Input) -> Option<&str> {
522    match input {
523        Input::WalletCreate { id, .. }
524        | Input::LnWalletCreate { id, .. }
525        | Input::WalletClose { id, .. }
526        | Input::WalletList { id, .. }
527        | Input::Balance { id, .. }
528        | Input::Receive { id, .. }
529        | Input::ReceiveClaim { id, .. }
530        | Input::CashuSend { id, .. }
531        | Input::CashuReceive { id, .. }
532        | Input::Send { id, .. }
533        | Input::Restore { id, .. }
534        | Input::WalletShowSeed { id, .. }
535        | Input::HistoryList { id, .. }
536        | Input::HistoryStatus { id, .. }
537        | Input::HistoryUpdate { id, .. }
538        | Input::LimitAdd { id, .. }
539        | Input::LimitRemove { id, .. }
540        | Input::LimitList { id, .. }
541        | Input::LimitSet { id, .. }
542        | Input::WalletConfigShow { id, .. }
543        | Input::WalletConfigSet { id, .. }
544        | Input::WalletConfigTokenAdd { id, .. }
545        | Input::WalletConfigTokenRemove { id, .. } => Some(id.as_str()),
546        Input::Config(_) | Input::ConfigShow { .. } | Input::Version | Input::Close => None,
547    }
548}