Skip to main content

agent_first_pay/rpc/
mod.rs

1pub mod crypto;
2
3use crate::handler::{self, App};
4use crate::rpc::crypto::Cipher;
5use crate::types::*;
6use std::io::Write;
7use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
8use std::sync::Arc;
9use tokio::sync::mpsc;
10use tonic::Code;
11use tonic::{Request, Response, Status};
12
/// Startup parameters for [`run_rpc`].
pub struct RpcInit {
    /// Listen address for the gRPC server, parsed as a socket address
    /// (`host:port`, e.g. `127.0.0.1:9100`).
    pub listen: String,
    /// Shared secret the request/response `Cipher` is built from. Required: `run_rpc` exits
    /// when this is `None` or empty.
    pub rpc_secret: Option<String>,
    /// Log event filters; when non-empty they replace `log` from the loaded `RuntimeConfig`.
    pub log: Vec<String>,
    /// Directory to load `RuntimeConfig` from; `None` uses `RuntimeConfig::default().data_dir`.
    pub data_dir: Option<String>,
    /// Argument vector passed to `crate::config::maybe_startup_log` for the startup log.
    pub startup_argv: Vec<String>,
    /// Structured startup arguments, also passed to `maybe_startup_log`.
    pub startup_args: serde_json::Value,
    /// Flag passed to `maybe_startup_log`; presumably requests the startup log even without a
    /// matching log filter — confirm against `maybe_startup_log`.
    pub startup_requested: bool,
}
22
/// Code generated at build time for the `afpay` protobuf package, pulled in with
/// `tonic::include_proto!` (provides `AfPay`, `AfPayServer`, `EncryptedRequest`,
/// `EncryptedResponse`, ...).
pub mod proto {
    tonic::include_proto!("afpay");
}
26
27use proto::af_pay_server::{AfPay, AfPayServer};
28use proto::{EncryptedRequest, EncryptedResponse};
29
/// gRPC service implementation: each call carries an encrypted JSON `Input` and is answered
/// with an encrypted JSON array of the outputs the handler produced.
struct AfPayService {
    /// Cipher built from the shared RPC secret; decrypts requests and encrypts responses.
    cipher: Cipher,
    /// Runtime configuration; cloned into the per-request `App` and used for log filtering.
    config: RuntimeConfig,
    /// Limiter checked before decryption; `None` when `config.rate_limit` is not set.
    rate_limiter: Option<RpcRateLimiter>,
}
35
/// Simple token-bucket rate limiter for RPC, with an optional cap on in-flight requests.
///
/// Token balances are kept in thousandths of a token ("milli-tokens") so refills computed
/// from millisecond intervals stay in integer arithmetic.
struct RpcRateLimiter {
    // Refill rate in tokens per second; the bucket holds at most `rps` tokens (one second of
    // burst). 0 disables the token check.
    rps: u32,
    // Maximum number of simultaneously admitted requests; 0 disables the concurrency check.
    max_concurrent: u32,
    // Requests currently holding a guard (only maintained when `max_concurrent > 0`).
    in_flight: AtomicU32,
    // Current bucket balance in milli-tokens.
    tokens_milli: AtomicU64,
    // `rpc_now_ms()` reading at the last successful refill.
    last_refill_ms: AtomicU64,
}
44
45impl RpcRateLimiter {
46    fn new(config: &RateLimitConfig) -> Self {
47        let rps = config.requests_per_second;
48        Self {
49            rps,
50            max_concurrent: config.max_concurrent,
51            in_flight: AtomicU32::new(0),
52            tokens_milli: AtomicU64::new(u64::from(rps) * 1000),
53            last_refill_ms: AtomicU64::new(rpc_now_ms()),
54        }
55    }
56
57    fn try_acquire(&self) -> Result<RpcRateLimitGuard<'_>, ()> {
58        if self.max_concurrent > 0 {
59            let prev = self.in_flight.fetch_add(1, Ordering::Relaxed);
60            if prev >= self.max_concurrent {
61                self.in_flight.fetch_sub(1, Ordering::Relaxed);
62                return Err(());
63            }
64        }
65        if self.rps > 0 {
66            self.refill();
67            let cost = 1000u64;
68            loop {
69                let current = self.tokens_milli.load(Ordering::Relaxed);
70                if current < cost {
71                    if self.max_concurrent > 0 {
72                        self.in_flight.fetch_sub(1, Ordering::Relaxed);
73                    }
74                    return Err(());
75                }
76                if self
77                    .tokens_milli
78                    .compare_exchange_weak(
79                        current,
80                        current - cost,
81                        Ordering::Relaxed,
82                        Ordering::Relaxed,
83                    )
84                    .is_ok()
85                {
86                    break;
87                }
88            }
89        }
90        Ok(RpcRateLimitGuard { limiter: self })
91    }
92
93    fn refill(&self) {
94        let now = rpc_now_ms();
95        let last = self.last_refill_ms.load(Ordering::Relaxed);
96        let elapsed = now.saturating_sub(last);
97        if elapsed == 0 {
98            return;
99        }
100        if self
101            .last_refill_ms
102            .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
103            .is_ok()
104        {
105            let add = elapsed * u64::from(self.rps);
106            let max = u64::from(self.rps) * 1000;
107            let _ = self
108                .tokens_milli
109                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |c| {
110                    Some(c.saturating_add(add).min(max))
111                });
112        }
113    }
114}
115
116struct RpcRateLimitGuard<'a> {
117    limiter: &'a RpcRateLimiter,
118}
119
120impl Drop for RpcRateLimitGuard<'_> {
121    fn drop(&mut self) {
122        if self.limiter.max_concurrent > 0 {
123            self.limiter.in_flight.fetch_sub(1, Ordering::Relaxed);
124        }
125    }
126}
127
/// Milliseconds on a monotonic clock, counted from the first call in this process.
///
/// Callers only compare two readings, so the origin is arbitrary. `Instant` is used instead of
/// `SystemTime` because the wall clock can jump backwards (NTP, manual changes). With a
/// backwards jump, `now - last` saturated to zero and token refills stalled until the wall
/// clock caught up again.
fn rpc_now_ms() -> u64 {
    static ORIGIN: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
    let origin = *ORIGIN.get_or_init(std::time::Instant::now);
    u64::try_from(origin.elapsed().as_millis()).unwrap_or(u64::MAX)
}
134
135#[tonic::async_trait]
136impl AfPay for AfPayService {
137    async fn call(
138        &self,
139        request: Request<EncryptedRequest>,
140    ) -> Result<Response<EncryptedResponse>, Status> {
141        let req = request.into_inner();
142
143        // Rate limit check
144        let _rate_guard = if let Some(rl) = &self.rate_limiter {
145            match rl.try_acquire() {
146                Ok(guard) => Some(guard),
147                Err(()) => {
148                    return Err(Status::resource_exhausted("rate limit exceeded"));
149                }
150            }
151        } else {
152            None
153        };
154
155        // Decrypt request
156        let plaintext = match self.cipher.decrypt(&req.nonce, &req.ciphertext) {
157            Ok(plaintext) => plaintext,
158            Err(_) => {
159                emit_rpc_request_log(
160                    &self.config,
161                    None,
162                    serde_json::json!({
163                        "input": serde_json::Value::Null,
164                        "decode_error": "decryption failed",
165                    }),
166                );
167                let status = Status::unauthenticated("decryption failed");
168                emit_rpc_response_log(&self.config, None, &[], Some(&status));
169                return Err(status);
170            }
171        };
172
173        let mut raw_input_value = serde_json::from_slice::<serde_json::Value>(&plaintext)
174            .unwrap_or(serde_json::Value::Null);
175        if let Some(object) = raw_input_value.as_object_mut() {
176            object.remove("id");
177        }
178
179        // Parse Input
180        let input: Input = match serde_json::from_slice(&plaintext) {
181            Ok(input) => input,
182            Err(e) => {
183                emit_rpc_request_log(
184                    &self.config,
185                    None,
186                    serde_json::json!({
187                        "input": raw_input_value,
188                        "decode_error": format!("invalid input: {e}"),
189                    }),
190                );
191                let status = Status::invalid_argument(format!("invalid input: {e}"));
192                emit_rpc_response_log(&self.config, None, &[], Some(&status));
193                return Err(status);
194            }
195        };
196        let request_id = input_request_id(&input).map(|s| s.to_string());
197        emit_rpc_request_log(
198            &self.config,
199            request_id.clone(),
200            serde_json::json!({
201                "input": raw_input_value,
202            }),
203        );
204
205        // Block local-only operations over RPC
206        if input.is_local_only() {
207            let status = Status::permission_denied("local-only operation");
208            emit_rpc_response_log(&self.config, request_id, &[], Some(&status));
209            return Err(status);
210        }
211
212        // Create per-request channel and App
213        let (tx, mut rx) = mpsc::channel::<Output>(256);
214        let store = crate::store::create_storage_backend(&self.config);
215        let app = Arc::new(App::new(self.config.clone(), tx, Some(true), store));
216        app.requests_total.fetch_add(1, Ordering::Relaxed);
217
218        // Dispatch
219        handler::dispatch(&app, input).await;
220
221        // Drop app to close the sender side, then collect all outputs
222        drop(app);
223        let mut outputs = Vec::new();
224        while let Some(out) = rx.recv().await {
225            // Mirror server-side log events to rpc daemon stdout so operators can
226            // observe request flow in long-running rpc mode.
227            if let Output::Log { .. } = &out {
228                let rendered = agent_first_data::cli_output(
229                    &serde_json::to_value(&out).unwrap_or(serde_json::Value::Null),
230                    agent_first_data::OutputFormat::Json,
231                );
232                let _ = writeln!(std::io::stdout(), "{rendered}");
233            }
234            let value = serde_json::to_value(&out).unwrap_or(serde_json::Value::Null);
235            outputs.push(value);
236        }
237
238        // Serialize outputs as JSON array
239        let response_json = match serde_json::to_vec(&outputs) {
240            Ok(response_json) => response_json,
241            Err(e) => {
242                let status = Status::internal(format!("serialize: {e}"));
243                emit_rpc_response_log(&self.config, request_id, &outputs, Some(&status));
244                return Err(status);
245            }
246        };
247
248        // Encrypt response
249        let (nonce, ciphertext) = match self.cipher.encrypt(&response_json) {
250            Ok(payload) => payload,
251            Err(e) => {
252                let status = Status::internal(format!("encrypt: {e}"));
253                emit_rpc_response_log(&self.config, request_id, &outputs, Some(&status));
254                return Err(status);
255            }
256        };
257
258        emit_rpc_response_log(&self.config, request_id, &outputs, None);
259
260        Ok(Response::new(EncryptedResponse { nonce, ciphertext }))
261    }
262}
263
264pub async fn run_rpc(init: RpcInit) {
265    let secret: String = match init.rpc_secret {
266        Some(s) if !s.is_empty() => s,
267        _ => {
268            let value = agent_first_data::build_cli_error(
269                "--rpc-secret is required for RPC mode",
270                Some("pass a shared secret for client authentication"),
271            );
272            let rendered =
273                agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
274            let _ = writeln!(std::io::stdout(), "{rendered}");
275            std::process::exit(1);
276        }
277    };
278
279    let cipher = Cipher::from_secret(&secret);
280
281    let resolved_dir = init
282        .data_dir
283        .unwrap_or_else(|| RuntimeConfig::default().data_dir);
284    let mut config = match RuntimeConfig::load_from_dir(&resolved_dir) {
285        Ok(c) => c,
286        Err(e) => {
287            let value = agent_first_data::build_cli_error(&e, None);
288            let rendered =
289                agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
290            let _ = writeln!(std::io::stdout(), "{rendered}");
291            std::process::exit(1);
292        }
293    };
294    if !init.log.is_empty() {
295        config.log = init.log;
296    }
297
298    // Emit startup log
299    if let Some(startup) = crate::config::maybe_startup_log(
300        &config.log,
301        init.startup_requested,
302        Some(init.startup_argv),
303        Some(&config),
304        init.startup_args,
305    ) {
306        let value = serde_json::to_value(&startup).unwrap_or(serde_json::Value::Null);
307        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
308        let _ = writeln!(std::io::stdout(), "{rendered}");
309    }
310
311    let startup_errors = crate::handler::startup_provider_validation_errors(&config).await;
312    for error_output in &startup_errors {
313        let value = serde_json::to_value(error_output).unwrap_or(serde_json::Value::Null);
314        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
315        let _ = writeln!(std::io::stdout(), "{rendered}");
316    }
317    if !startup_errors.is_empty() {
318        std::process::exit(1);
319    }
320
321    let rate_limiter = config.rate_limit.as_ref().map(RpcRateLimiter::new);
322    let service = AfPayService {
323        cipher,
324        config,
325        rate_limiter,
326    };
327
328    let addr = match init.listen.parse() {
329        Ok(a) => a,
330        Err(e) => {
331            let value = agent_first_data::build_cli_error(
332                &format!("invalid --rpc-listen address: {e}"),
333                Some("expected format: host:port (e.g. 127.0.0.1:9100)"),
334            );
335            let rendered =
336                agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
337            let _ = writeln!(std::io::stdout(), "{rendered}");
338            std::process::exit(1);
339        }
340    };
341
342    let server = tonic::transport::Server::builder()
343        .add_service(AfPayServer::new(service))
344        .serve(addr);
345
346    if let Err(e) = server.await {
347        let value = agent_first_data::build_cli_error(&format!("RPC server error: {e}"), None);
348        let rendered = agent_first_data::cli_output(&value, agent_first_data::OutputFormat::Json);
349        let _ = writeln!(std::io::stdout(), "{rendered}");
350        std::process::exit(1);
351    }
352}
353
/// Returns whether `event` passes the configured log filters.
///
/// An empty filter list disables all events. A filter of `*` or `all` enables everything;
/// any other filter enables events it is a prefix of. Matching is ASCII case-insensitive on
/// both sides. Previously only the event was lowercased, so filters such as `RPC` or `ALL`
/// could never match.
fn log_event_enabled(log_filters: &[String], event: &str) -> bool {
    log_filters.iter().any(|filter| {
        filter == "*"
            || filter.eq_ignore_ascii_case("all")
            || event
                .as_bytes()
                .get(..filter.len())
                .is_some_and(|prefix| prefix.eq_ignore_ascii_case(filter.as_bytes()))
    })
}
363
364fn emit_rpc_request_log(
365    config: &RuntimeConfig,
366    request_id: Option<String>,
367    args: serde_json::Value,
368) {
369    emit_rpc_log(config, "rpc_request", request_id, args);
370}
371
372fn emit_rpc_response_log(
373    config: &RuntimeConfig,
374    request_id: Option<String>,
375    outputs: &[serde_json::Value],
376    status: Option<&Status>,
377) {
378    let has_output_error = outputs
379        .iter()
380        .any(|item| item.get("code").and_then(|v| v.as_str()) == Some("error"));
381    let mut args = serde_json::json!({
382        "output_count": outputs.len(),
383        "has_error": has_output_error || status.is_some(),
384        "outputs": outputs,
385    });
386    if let Some(status) = status {
387        if let Some(object) = args.as_object_mut() {
388            object.insert(
389                "grpc_error".to_string(),
390                serde_json::json!({
391                    "code": grpc_code_name(status.code()),
392                    "message": status.message(),
393                }),
394            );
395        }
396    }
397    emit_rpc_log(config, "rpc_response", request_id, args);
398}
399
400fn emit_rpc_log(
401    config: &RuntimeConfig,
402    event: &str,
403    request_id: Option<String>,
404    args: serde_json::Value,
405) {
406    if !log_event_enabled(&config.log, event) {
407        return;
408    }
409    let log = Output::Log {
410        event: event.to_string(),
411        request_id,
412        version: None,
413        argv: None,
414        config: None,
415        args: Some(args),
416        env: None,
417        trace: Trace::from_duration(0),
418    };
419    let rendered = agent_first_data::cli_output(
420        &serde_json::to_value(&log).unwrap_or(serde_json::Value::Null),
421        agent_first_data::OutputFormat::Json,
422    );
423    let _ = writeln!(std::io::stdout(), "{rendered}");
424}
425
/// Maps a gRPC status code to the lowercase snake_case name written to the `grpc_error.code`
/// field of `rpc_response` logs.
///
/// The match has no `_` arm, so a new `Code` variant becomes a compile error here instead of
/// an unnamed code in the logs.
fn grpc_code_name(code: Code) -> &'static str {
    match code {
        Code::Ok => "ok",
        Code::Cancelled => "cancelled",
        Code::Unknown => "unknown",
        Code::InvalidArgument => "invalid_argument",
        Code::DeadlineExceeded => "deadline_exceeded",
        Code::NotFound => "not_found",
        Code::AlreadyExists => "already_exists",
        Code::PermissionDenied => "permission_denied",
        Code::ResourceExhausted => "resource_exhausted",
        Code::FailedPrecondition => "failed_precondition",
        Code::Aborted => "aborted",
        Code::OutOfRange => "out_of_range",
        Code::Unimplemented => "unimplemented",
        Code::Internal => "internal",
        Code::Unavailable => "unavailable",
        Code::DataLoss => "data_loss",
        Code::Unauthenticated => "unauthenticated",
    }
}
447
448fn input_request_id(input: &Input) -> Option<&str> {
449    match input {
450        Input::WalletCreate { id, .. }
451        | Input::LnWalletCreate { id, .. }
452        | Input::WalletClose { id, .. }
453        | Input::WalletList { id, .. }
454        | Input::Balance { id, .. }
455        | Input::Receive { id, .. }
456        | Input::ReceiveClaim { id, .. }
457        | Input::CashuSend { id, .. }
458        | Input::CashuReceive { id, .. }
459        | Input::Send { id, .. }
460        | Input::Restore { id, .. }
461        | Input::WalletShowSeed { id, .. }
462        | Input::HistoryList { id, .. }
463        | Input::HistoryStatus { id, .. }
464        | Input::HistoryUpdate { id, .. }
465        | Input::LimitAdd { id, .. }
466        | Input::LimitRemove { id, .. }
467        | Input::LimitList { id, .. }
468        | Input::LimitSet { id, .. }
469        | Input::WalletConfigShow { id, .. }
470        | Input::WalletConfigSet { id, .. }
471        | Input::WalletConfigTokenAdd { id, .. }
472        | Input::WalletConfigTokenRemove { id, .. } => Some(id.as_str()),
473        Input::Config(_) | Input::Version | Input::Close => None,
474    }
475}