Skip to main content

tauri_plugin_conduit/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(missing_docs)]
3//! # tauri-plugin-conduit
4//!
5//! Tauri v2 plugin for conduit — binary IPC over the `conduit://` custom
6//! protocol.
7//!
8//! Registers a `conduit://` custom protocol for zero-overhead in-process
9//! binary dispatch. Supports both sync and async handlers via
10//! [`ConduitHandler`](conduit_core::ConduitHandler). No network surface.
11//!
12//! ## Usage
13//!
14//! ```rust,ignore
15//! use tauri_conduit::{command, handler};
16//!
17//! #[command]
18//! fn greet(name: String) -> String {
19//!     format!("Hello, {name}!")
20//! }
21//!
22//! #[command]
23//! async fn fetch_user(state: State<'_, Db>, id: u64) -> Result<User, String> {
24//!     state.get_user(id).await.map_err(|e| e.to_string())
25//! }
26//!
27//! tauri::Builder::default()
28//!     .plugin(
29//!         tauri_plugin_conduit::init()
30//!             .handler("greet", handler!(greet))
31//!             .handler("fetch_user", handler!(fetch_user))
32//!             .channel("telemetry")
33//!             .build()
34//!     )
35//!     .run(tauri::generate_context!())
36//!     .unwrap();
37//! ```
38
39/// Re-export the `#[command]` attribute macro from `conduit-derive`.
40///
41/// This is conduit's equivalent of `#[tauri::command]`. Use it for
42/// named-parameter handlers:
43///
44/// ```rust,ignore
45/// use tauri_conduit::{command, handler};
46///
47/// #[command]
48/// fn greet(name: String, greeting: String) -> String {
49///     format!("{greeting}, {name}!")
50/// }
51/// ```
52pub use conduit_derive::command;
53
54/// Re-export the `handler!()` macro from `conduit-derive`.
55///
56/// Resolves a `#[command]` function name to its conduit handler struct
57/// for registration:
58///
59/// ```rust,ignore
60/// tauri_plugin_conduit::init()
61///     .handler("greet", handler!(greet))
62///     .build()
63/// ```
64pub use conduit_derive::handler;
65
66use std::collections::HashMap;
67use std::sync::Arc;
68
69use conduit_core::{
70    ChannelBuffer, ConduitHandler, Decode, Encode, HandlerResponse, Queue, RingBuffer, Router,
71};
72use futures_util::FutureExt;
73use subtle::ConstantTimeEq;
74use tauri::plugin::{Builder as TauriPluginBuilder, TauriPlugin};
75use tauri::{AppHandle, Emitter, Manager, Runtime};
76
77// ---------------------------------------------------------------------------
78// Helper: safe HTTP response builder
79// ---------------------------------------------------------------------------
80
81/// Build an HTTP response, falling back to a minimal 500 if construction fails.
82fn make_response(status: u16, content_type: &str, body: Vec<u8>) -> http::Response<Vec<u8>> {
83    http::Response::builder()
84        .status(status)
85        .header("Content-Type", content_type)
86        .header("Access-Control-Allow-Origin", "*")
87        .body(body)
88        .unwrap_or_else(|_| {
89            http::Response::builder()
90                .status(500)
91                .body(b"internal error".to_vec())
92                .expect("fallback response must not fail")
93        })
94}
95
96/// Build a JSON error response: `{"error": "message"}`.
97///
98/// Uses `sonic_rs` for proper RFC 8259 escaping of all control characters,
99/// newlines, quotes, and backslashes — not just `\` and `"`.
100fn make_error_response(status: u16, message: &str) -> http::Response<Vec<u8>> {
101    #[derive(serde::Serialize)]
102    struct ErrorBody<'a> {
103        error: &'a str,
104    }
105    let body = conduit_core::sonic_rs::to_vec(&ErrorBody { error: message })
106        .unwrap_or_else(|_| br#"{"error":"internal error"}"#.to_vec());
107    make_response(status, "application/json", body)
108}
109
110// ---------------------------------------------------------------------------
111// BootstrapInfo — returned to JS via `conduit_bootstrap` command
112// ---------------------------------------------------------------------------
113
/// Connection info returned to the frontend during bootstrap.
///
/// Serialized with camelCase field names (`protocolVersion`, `protocolBase`,
/// `invokeKey`, `channels`). The hand-written `Debug` impl for this type
/// prints `[REDACTED]` in place of the invoke key.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BootstrapInfo {
    /// Protocol version (currently `1`). Allows the TS client to verify
    /// protocol compatibility.
    ///
    /// When the field is absent during deserialization it falls back to
    /// `default_protocol_version()`.
    #[serde(default = "default_protocol_version")]
    pub protocol_version: u8,
    /// Base URL for the custom protocol (e.g., `"conduit://localhost"`).
    pub protocol_base: String,
    /// Per-launch invoke key for custom protocol authentication (hex-encoded).
    ///
    /// **Security**: This key authenticates custom protocol requests. It is
    /// generated fresh each launch from 32 bytes of OS randomness and validated
    /// using constant-time comparison. The JS client includes it as the
    /// `X-Conduit-Key` header on every `conduit://` request.
    pub invoke_key: String,
    /// Available channel names.
    pub channels: Vec<String>,
}
134
/// Serde default for `BootstrapInfo::protocol_version` when the field is
/// missing from the incoming payload.
fn default_protocol_version() -> u8 {
    const CURRENT_PROTOCOL_VERSION: u8 = 1;
    CURRENT_PROTOCOL_VERSION
}
138
139impl std::fmt::Debug for BootstrapInfo {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        f.debug_struct("BootstrapInfo")
142            .field("protocol_version", &self.protocol_version)
143            .field("protocol_base", &self.protocol_base)
144            .field("invoke_key", &"[REDACTED]")
145            .field("channels", &self.channels)
146            .finish()
147    }
148}
149
150// ---------------------------------------------------------------------------
151// PluginState — managed Tauri state
152// ---------------------------------------------------------------------------
153
/// Shared state for the conduit Tauri plugin.
///
/// Holds the router, named streaming channels, the per-launch invoke key,
/// and the app handle for emitting push notifications.
///
/// The custom `Debug` impl prints only the channel names and a redacted
/// invoke key.
pub struct PluginState<R: Runtime> {
    /// Router for the closure-based command registrations.
    // NOTE(review): presumably the fallback tier described on
    // `PluginBuilder::build`; the lookup itself is outside this view — confirm.
    dispatch: Arc<Router>,
    /// `#[command]`-generated handlers (sync and async via [`ConduitHandler`]).
    /// Looked up by command name before anything else on `invoke` requests.
    handlers: Arc<HashMap<String, Arc<dyn ConduitHandler>>>,
    /// Named channels for server→client streaming (lossy or ordered).
    channels: HashMap<String, Arc<ChannelBuffer>>,
    /// Tauri app handle for emitting events to the frontend.
    app_handle: AppHandle<R>,
    /// Pre-cached `Arc` of the app handle — avoids a heap allocation per request.
    app_handle_arc: Arc<AppHandle<R>>,
    /// Per-launch invoke key (hex-encoded, 64 hex chars = 32 bytes).
    /// This is the string handed to the frontend by `conduit_bootstrap`.
    invoke_key: String,
    /// Raw invoke key bytes for constant-time comparison.
    invoke_key_bytes: [u8; 32],
}
173
174impl<R: Runtime> std::fmt::Debug for PluginState<R> {
175    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176        f.debug_struct("PluginState")
177            .field("channels", &self.channels.keys().collect::<Vec<_>>())
178            .field("invoke_key", &"[REDACTED]")
179            .finish()
180    }
181}
182
183impl<R: Runtime> PluginState<R> {
184    /// Get a channel by name (for pushing data from Rust handlers).
185    pub fn channel(&self, name: &str) -> Option<&Arc<ChannelBuffer>> {
186        self.channels.get(name)
187    }
188
189    /// Push binary data to a named channel and notify JS listeners.
190    ///
191    /// After writing to the channel, emits both a global
192    /// `conduit:data-available` event (payload = channel name) and a
193    /// per-channel `conduit:data-available:{channel}` event. JS subscribers
194    /// can listen on either.
195    ///
196    /// For lossy channels, oldest frames are silently dropped when the buffer
197    /// is full. For reliable channels, returns an error if the buffer is full
198    /// (backpressure).
199    ///
200    /// Returns an error string if the named channel was not registered via
201    /// the builder or if a reliable channel is full.
202    pub fn push(&self, channel: &str, data: &[u8]) -> Result<(), String> {
203        let ch = self
204            .channels
205            .get(channel)
206            .ok_or_else(|| format!("unknown channel: {channel}"))?;
207        ch.push(data).map(|_| ()).map_err(|e| e.to_string())?;
208        // Emit global event (backward-compatible with old JS code).
209        if self
210            .app_handle
211            .emit("conduit:data-available", channel)
212            .is_err()
213        {
214            #[cfg(debug_assertions)]
215            eprintln!(
216                "conduit: failed to emit global data-available event for channel '{channel}'"
217            );
218        }
219        // Emit per-channel event.
220        if self
221            .app_handle
222            .emit(&format!("conduit:data-available:{channel}"), channel)
223            .is_err()
224        {
225            #[cfg(debug_assertions)]
226            eprintln!(
227                "conduit: failed to emit per-channel data-available event for channel '{channel}'"
228            );
229        }
230        Ok(())
231    }
232
233    /// Return the list of registered channel names.
234    pub fn channel_names(&self) -> Vec<String> {
235        self.channels.keys().cloned().collect()
236    }
237
238    /// Validate an invoke key candidate using constant-time operations.
239    fn validate_invoke_key(&self, candidate: &str) -> bool {
240        validate_invoke_key_ct(&self.invoke_key_bytes, candidate)
241    }
242}
243
244// ---------------------------------------------------------------------------
245// Tauri commands
246// ---------------------------------------------------------------------------
247
248/// Return bootstrap info so the JS client knows how to reach the conduit
249/// custom protocol.
250///
251/// May be called multiple times (e.g., after page reloads during development).
252/// The invoke key is generated once at plugin setup and remains constant for
253/// the lifetime of the app process. Repeated calls return the same key.
254#[tauri::command]
255fn conduit_bootstrap(
256    state: tauri::State<'_, PluginState<tauri::Wry>>,
257) -> Result<BootstrapInfo, String> {
258    Ok(BootstrapInfo {
259        protocol_version: 1,
260        protocol_base: "conduit://localhost".to_string(),
261        invoke_key: state.invoke_key.clone(),
262        channels: state.channel_names(),
263    })
264}
265
266/// Validate channel names and return those that exist.
267///
268/// This is a validation-only endpoint — no server-side subscription state is
269/// tracked. The JS client uses the returned list to know which channels are
270/// available. Actual data delivery happens via `conduit:data-available` events
271/// and `conduit://localhost/drain/<channel>` protocol requests.
272///
273/// Unknown channel names are silently filtered out — only channels that
274/// exist are returned.
275#[tauri::command]
276fn conduit_subscribe(
277    state: tauri::State<'_, PluginState<tauri::Wry>>,
278    channels: Vec<String>,
279) -> Result<Vec<String>, String> {
280    // Silently filter to only channels that exist.
281    let valid: Vec<String> = channels
282        .into_iter()
283        .filter(|c| state.channels.contains_key(c.as_str()))
284        .collect();
285    Ok(valid)
286}
287
288// ---------------------------------------------------------------------------
289// Channel kind (internal)
290// ---------------------------------------------------------------------------
291
/// Internal enum for deferred channel construction.
///
/// The builder records one of these per channel name; the variant's payload
/// is the size argument passed to the corresponding builder method.
enum ChannelKind {
    /// Lossy ring buffer with the given byte capacity.
    /// Recorded by `PluginBuilder::channel` and `channel_with_capacity`.
    Lossy(usize),
    /// Reliable queue with the given max byte limit.
    /// Recorded by `PluginBuilder::channel_ordered` and
    /// `channel_ordered_with_capacity`.
    Reliable(usize),
}
299
300// ---------------------------------------------------------------------------
301// Plugin builder
302// ---------------------------------------------------------------------------
303
/// A deferred command registration closure.
///
/// Each builder method that registers a closure-based command boxes a
/// one-shot closure that performs the actual `Router` registration when
/// called with the router.
// NOTE(review): presumably invoked once from `PluginBuilder::build`; that
// call site is outside this view — confirm.
type CommandRegistration = Box<dyn FnOnce(&Router) + Send>;
306
/// Builder for the conduit Tauri v2 plugin.
///
/// Collects command registrations and configuration, then produces a
/// [`TauriPlugin`] via [`build`](Self::build).
pub struct PluginBuilder {
    /// Deferred command registrations: one boxed closure per closure-based
    /// command (`command`, `handler_raw`, `command_json`,
    /// `command_json_result`, `command_binary`). The command name and handler
    /// are captured inside the closure.
    commands: Vec<CommandRegistration>,
    /// `#[command]`-generated handlers (sync and async), as (name, handler).
    handler_defs: Vec<(String, Arc<dyn ConduitHandler>)>,
    /// Named channels: (name, kind).
    channel_defs: Vec<(String, ChannelKind)>,
}
319
320impl std::fmt::Debug for PluginBuilder {
321    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322        f.debug_struct("PluginBuilder")
323            .field("commands", &self.commands.len())
324            .field("handlers", &self.handler_defs.len())
325            .field("channel_defs_count", &self.channel_defs.len())
326            .finish()
327    }
328}
329
/// Check that a channel name matches `[a-zA-Z0-9_-]+`.
///
/// # Panics
///
/// Panics if `name` is empty or contains any byte other than an ASCII
/// letter, an ASCII digit, `_`, or `-`.
fn validate_channel_name(name: &str) {
    let is_allowed =
        |b: u8| matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-');
    let well_formed = !name.is_empty() && name.bytes().all(is_allowed);

    if !well_formed {
        panic!(
            "conduit: invalid channel name '{}' — must match [a-zA-Z0-9_-]+",
            name
        );
    }
}
341
/// Default channel capacity in bytes (64 KB = 65,536), used by
/// `PluginBuilder::channel` and `PluginBuilder::channel_ordered` when no
/// explicit size is given.
const DEFAULT_CHANNEL_CAPACITY: usize = 64 * 1024;
344
345impl PluginBuilder {
346    /// Panic if a channel with the given name is already registered.
347    fn assert_no_duplicate_channel(&self, name: &str) {
348        if self.channel_defs.iter().any(|(n, _)| n == name) {
349            panic!(
350                "conduit: duplicate channel name '{}' — each channel must have a unique name",
351                name
352            );
353        }
354    }
355
356    /// Create a new, empty plugin builder.
357    pub fn new() -> Self {
358        Self {
359            commands: Vec::new(),
360            handler_defs: Vec::new(),
361            channel_defs: Vec::new(),
362        }
363    }
364
365    // -- Raw handlers -------------------------------------------------------
366
367    /// Register a raw command handler (`Vec<u8>` in, `Vec<u8>` out).
368    ///
369    /// Command names correspond to the path segment in the
370    /// `conduit://localhost/invoke/<cmd_name>` URL.
371    pub fn command<F>(mut self, name: impl Into<String>, handler: F) -> Self
372    where
373        F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
374    {
375        let name = name.into();
376        self.commands.push(Box::new(move |table: &Router| {
377            table.register(name, handler);
378        }));
379        self
380    }
381
382    // -- ConduitHandler-based (#[command]-generated, sync or async) ----------
383
384    /// Register a `#[tauri_conduit::command]`-generated handler.
385    ///
386    /// Works with both sync and async handlers. Sync handlers are dispatched
387    /// inline. Async handlers are spawned on the tokio runtime — truly async,
388    /// exactly like `#[tauri::command]`.
389    ///
390    /// ```rust,ignore
391    /// use tauri_conduit::{command, handler};
392    ///
393    /// #[command]
394    /// fn greet(name: String) -> String {
395    ///     format!("Hello, {name}!")
396    /// }
397    ///
398    /// #[command]
399    /// async fn fetch_user(state: State<'_, Db>, id: u64) -> Result<User, String> {
400    ///     state.get_user(id).await.map_err(|e| e.to_string())
401    /// }
402    ///
403    /// tauri_plugin_conduit::init()
404    ///     .handler("greet", handler!(greet))
405    ///     .handler("fetch_user", handler!(fetch_user))
406    ///     .build()
407    /// ```
408    pub fn handler(mut self, name: impl Into<String>, handler: impl ConduitHandler) -> Self {
409        self.handler_defs.push((name.into(), Arc::new(handler)));
410        self
411    }
412
413    /// Register a raw closure handler (legacy API).
414    ///
415    /// Accepts the same closure signature as the pre-`ConduitHandler` `.handler()`:
416    /// `Fn(Vec<u8>, &dyn Any) -> Result<Vec<u8>, Error>`. This is a synchronous
417    /// handler dispatched via `Router::register_with_context`.
418    ///
419    /// Use this for backward compatibility when migrating from closure-based
420    /// registration. For new code, prefer [`handler`](Self::handler) with
421    /// `#[tauri_conduit::command]` + `handler!()`.
422    pub fn handler_raw<F>(mut self, name: impl Into<String>, handler: F) -> Self
423    where
424        F: Fn(Vec<u8>, &dyn std::any::Any) -> Result<Vec<u8>, conduit_core::Error>
425            + Send
426            + Sync
427            + 'static,
428    {
429        let name = name.into();
430        self.commands.push(Box::new(move |table: &Router| {
431            table.register_with_context(name, handler);
432        }));
433        self
434    }
435
436    // -- JSON handlers (Level 1) --------------------------------------------
437
438    /// Typed JSON handler. Deserializes the request payload as `A` and
439    /// serializes the response as `R`.
440    ///
441    /// Unlike Tauri's `#[tauri::command]`, this takes a single argument type
442    /// (not named parameters) and does not support async or State injection.
443    ///
444    /// ```rust,ignore
445    /// .command_json("greet", |name: String| format!("Hello, {name}!"))
446    /// ```
447    pub fn command_json<F, A, R>(mut self, name: impl Into<String>, handler: F) -> Self
448    where
449        F: Fn(A) -> R + Send + Sync + 'static,
450        A: serde::de::DeserializeOwned + 'static,
451        R: serde::Serialize + 'static,
452    {
453        let name = name.into();
454        self.commands.push(Box::new(move |table: &Router| {
455            table.register_json(name, handler);
456        }));
457        self
458    }
459
460    /// Typed JSON handler that returns `Result<R, E>`.
461    ///
462    /// Like [`command_json`](Self::command_json), but the handler returns
463    /// `Result<R, E>` where `E: Display`. On success, `R` is serialized to
464    /// JSON. On error, the error's `Display` text is returned to the caller.
465    ///
466    /// For Tauri-style named parameters with `Result` returns, prefer
467    /// [`handler`](Self::handler) with `#[tauri_conduit::command]` instead:
468    ///
469    /// ```rust,ignore
470    /// use tauri_conduit::command;
471    ///
472    /// #[command]
473    /// fn divide(a: f64, b: f64) -> Result<f64, String> {
474    ///     if b == 0.0 { Err("division by zero".into()) }
475    ///     else { Ok(a / b) }
476    /// }
477    ///
478    /// // Preferred:
479    /// .handler("divide", divide)
480    /// ```
481    pub fn command_json_result<F, A, R, E>(mut self, name: impl Into<String>, handler: F) -> Self
482    where
483        F: Fn(A) -> Result<R, E> + Send + Sync + 'static,
484        A: serde::de::DeserializeOwned + 'static,
485        R: serde::Serialize + 'static,
486        E: std::fmt::Display + 'static,
487    {
488        let name = name.into();
489        self.commands.push(Box::new(move |table: &Router| {
490            table.register_json_result(name, handler);
491        }));
492        self
493    }
494
495    // -- Binary handlers (Level 2) ------------------------------------------
496
497    /// Register a typed binary command handler.
498    ///
499    /// The request payload is decoded via the [`Decode`] trait and the response
500    /// is encoded via [`Encode`]. No JSON involved — raw bytes in, raw bytes
501    /// out.
502    ///
503    /// ```rust,ignore
504    /// .command_binary("process", |tick: MarketTick| tick)
505    /// ```
506    pub fn command_binary<F, A, Ret>(mut self, name: impl Into<String>, handler: F) -> Self
507    where
508        F: Fn(A) -> Ret + Send + Sync + 'static,
509        A: Decode + 'static,
510        Ret: Encode + 'static,
511    {
512        let name = name.into();
513        self.commands.push(Box::new(move |table: &Router| {
514            table.register_binary(name, handler);
515        }));
516        self
517    }
518
519    // -- Lossy channels (default) -------------------------------------------
520
521    /// Register a lossy channel with the default capacity (64 KB).
522    ///
523    /// Oldest frames are silently dropped when the buffer is full. Best for
524    /// telemetry, game state, and real-time data where freshness matters more
525    /// than completeness.
526    ///
527    /// # Panics
528    ///
529    /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
530    /// or duplicates an already-registered channel name.
531    pub fn channel(mut self, name: impl Into<String>) -> Self {
532        let name = name.into();
533        validate_channel_name(&name);
534        self.assert_no_duplicate_channel(&name);
535        self.channel_defs
536            .push((name, ChannelKind::Lossy(DEFAULT_CHANNEL_CAPACITY)));
537        self
538    }
539
540    /// Register a lossy channel with a custom byte capacity.
541    ///
542    /// # Panics
543    ///
544    /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
545    /// or duplicates an already-registered channel name.
546    pub fn channel_with_capacity(mut self, name: impl Into<String>, capacity: usize) -> Self {
547        let name = name.into();
548        validate_channel_name(&name);
549        self.assert_no_duplicate_channel(&name);
550        self.channel_defs.push((name, ChannelKind::Lossy(capacity)));
551        self
552    }
553
554    // -- Reliable channels (guaranteed delivery) ----------------------------
555
556    /// Register an ordered channel with the default capacity (64 KB).
557    ///
558    /// No frames are ever dropped. When the buffer is full,
559    /// [`PluginState::push`] returns an error (backpressure). Best for
560    /// transaction logs, control messages, and any data that must arrive
561    /// intact and in order.
562    ///
563    /// # Panics
564    ///
565    /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
566    /// or duplicates an already-registered channel name.
567    pub fn channel_ordered(mut self, name: impl Into<String>) -> Self {
568        let name = name.into();
569        validate_channel_name(&name);
570        self.assert_no_duplicate_channel(&name);
571        self.channel_defs
572            .push((name, ChannelKind::Reliable(DEFAULT_CHANNEL_CAPACITY)));
573        self
574    }
575
576    /// Register an ordered channel with a custom byte limit.
577    ///
578    /// A `max_bytes` of `0` means unbounded — the buffer grows without limit.
579    ///
580    /// # Panics
581    ///
582    /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
583    /// or duplicates an already-registered channel name.
584    pub fn channel_ordered_with_capacity(
585        mut self,
586        name: impl Into<String>,
587        max_bytes: usize,
588    ) -> Self {
589        let name = name.into();
590        validate_channel_name(&name);
591        self.assert_no_duplicate_channel(&name);
592        self.channel_defs
593            .push((name, ChannelKind::Reliable(max_bytes)));
594        self
595    }
596
597    // -- Build --------------------------------------------------------------
598
599    /// Build the Tauri v2 plugin.
600    ///
601    /// This consumes the builder and returns a [`TauriPlugin`] that can be
602    /// passed to `tauri::Builder::plugin`.
603    ///
604    /// # Dispatch model
605    ///
606    /// Commands are dispatched through a two-tier system:
607    ///
608    /// 1. **`#[command]` handlers** (registered via [`.handler()`](Self::handler))
609    ///    are checked first. These support named parameters, `State<T>` injection,
610    ///    `Result` returns, and async — full parity with `#[tauri::command]`.
611    ///
612    /// 2. **Raw Router handlers** (registered via [`.command()`](Self::command),
613    ///    [`.command_json()`](Self::command_json), [`.command_binary()`](Self::command_binary))
614    ///    are the fallback. These are simpler `Vec<u8> -> Vec<u8>` functions
615    ///    with no injection or async support.
616    ///
617    /// If a command name exists in both tiers, the `#[command]` handler takes
618    /// priority and a debug warning is printed.
619    pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
620        let commands = self.commands;
621        let handler_defs = self.handler_defs;
622        let channel_defs = self.channel_defs;
623
624        TauriPluginBuilder::<R>::new("conduit")
625            // --- Custom protocol: conduit://localhost/invoke/<cmd> ---
626            // Uses the asynchronous variant so async #[command] handlers
627            // are spawned on tokio (truly async, like #[tauri::command]).
628            .register_asynchronous_uri_scheme_protocol("conduit", move |ctx, request, responder| {
629                // Handle CORS preflight requests.
630                if request.method() == "OPTIONS" {
631                    let resp = http::Response::builder()
632                        .status(204)
633                        .header("Access-Control-Allow-Origin", "*")
634                        .header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
635                        .header(
636                            "Access-Control-Allow-Headers",
637                            "Content-Type, X-Conduit-Key, X-Conduit-Webview",
638                        )
639                        .header("Access-Control-Max-Age", "86400")
640                        .body(Vec::new())
641                        .expect("preflight response must not fail");
642                    responder.respond(resp);
643                    return;
644                }
645
646                // Extract the managed PluginState from the app handle.
647                let state: tauri::State<'_, PluginState<R>> = ctx.app_handle().state();
648
649                // Extract path directly from the URI — zero allocation.
650                let path = request.uri().path();
651                let segments: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
652
653                if segments.len() != 2 {
654                    responder.respond(make_error_response(
655                        404,
656                        "not found: expected /invoke/<cmd> or /drain/<channel>",
657                    ));
658                    return;
659                }
660
661                // Validate the invoke key from the X-Conduit-Key header.
662                // Borrow the header value directly — no allocation needed.
663                let key = match request.headers().get("X-Conduit-Key") {
664                    Some(v) => match v.to_str() {
665                        Ok(s) => s,
666                        Err(_) => {
667                            responder
668                                .respond(make_error_response(401, "invalid invoke key header"));
669                            return;
670                        }
671                    },
672                    None => {
673                        responder.respond(make_error_response(401, "missing invoke key"));
674                        return;
675                    }
676                };
677
678                if !state.validate_invoke_key(key) {
679                    responder.respond(make_error_response(403, "invalid invoke key"));
680                    return;
681                }
682
683                let action = segments[0];
684                let raw_target = segments[1];
685
686                // H6: Percent-decode the target and reject path traversal.
687                let target = percent_decode(raw_target);
688                if target.contains('/') {
689                    responder.respond(make_error_response(400, "invalid command name"));
690                    return;
691                }
692
693                match action {
694                    "invoke" => {
695                        let body = request.body().to_vec();
696
697                        // 1) Check #[command]-generated handlers first (sync or async)
698                        if let Some(handler) = state.handlers.get(&*target) {
699                            let handler = Arc::clone(handler);
700                            // Extract webview label from X-Conduit-Webview header (sent by JS client).
701                            // NOTE: This header is client-provided and could be spoofed by JS
702                            // running in the same webview. We validate the format to prevent
703                            // injection attacks, but in a multi-webview app, code in one
704                            // webview could impersonate another. This matches Tauri's own
705                            // trust model where all JS in the webview is equally trusted.
706                            let webview_label = request
707                                .headers()
708                                .get("X-Conduit-Webview")
709                                .and_then(|v| v.to_str().ok())
710                                .filter(|s| {
711                                    !s.is_empty()
712                                        && s.len() <= 128
713                                        && s.bytes().all(|b| {
714                                            b.is_ascii_alphanumeric() || b == b'_' || b == b'-'
715                                        })
716                                })
717                                .map(|s| s.to_string());
718                            // Clone the pre-cached Arc and coerce to trait object —
719                            // one atomic increment, no heap allocation.
720                            let app_handle_arc: Arc<dyn std::any::Any + Send + Sync> =
721                                state.app_handle_arc.clone();
722                            let handler_ctx = conduit_core::HandlerContext::new(
723                                app_handle_arc,
724                                webview_label,
725                            );
726                            let ctx_any: Arc<dyn std::any::Any + Send + Sync> =
727                                Arc::new(handler_ctx);
728
729                            // SAFETY: AssertUnwindSafe is used here because:
730                            // - `body` is a Vec<u8> (unwind-safe by itself)
731                            // - `ctx_any` is an Arc (unwind-safe)
732                            // - conduit's own locks use poison-recovery helpers (lock_or_recover)
733                            // - User-defined handler state may be left inconsistent after panic,
734                            //   but this is inherent to catch_unwind and documented as a limitation.
735                            let result =
736                                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
737                                    handler.call(body, ctx_any)
738                                }));
739
740                            match result {
741                                Ok(HandlerResponse::Sync(Ok(bytes))) => {
742                                    responder.respond(make_response(
743                                        200,
744                                        "application/octet-stream",
745                                        bytes,
746                                    ));
747                                }
748                                Ok(HandlerResponse::Sync(Err(e))) => {
749                                    let status = error_to_status(&e);
750                                    responder
751                                        .respond(make_error_response(status, &sanitize_error(&e)));
752                                }
753                                Ok(HandlerResponse::Async(future)) => {
754                                    // Truly async — spawned on tokio, just like #[tauri::command].
755                                    // Single spawn with catch_unwind for panic isolation.
756                                    tauri::async_runtime::spawn(async move {
757                                        let result = std::panic::AssertUnwindSafe(future)
758                                            .catch_unwind()
759                                            .await;
760                                        match result {
761                                            Ok(Ok(bytes)) => {
762                                                responder.respond(make_response(
763                                                    200,
764                                                    "application/octet-stream",
765                                                    bytes,
766                                                ));
767                                            }
768                                            Ok(Err(e)) => {
769                                                let status = error_to_status(&e);
770                                                responder.respond(make_error_response(
771                                                    status,
772                                                    &sanitize_error(&e),
773                                                ));
774                                            }
775                                            Err(_) => {
776                                                // Panic during async handler execution
777                                                responder.respond(make_error_response(
778                                                    500,
779                                                    "handler panicked",
780                                                ));
781                                            }
782                                        }
783                                    });
784                                }
785                                Err(_) => {
786                                    // Panic caught by catch_unwind — keep as 500.
787                                    responder.respond(make_error_response(500, "handler panicked"));
788                                }
789                            }
790                        } else {
791                            // 2) Fall back to legacy sync Router
792                            let dispatch = Arc::clone(&state.dispatch);
793                            // Use the app_handle reference from state — no clone needed.
794                            let app_handle_ref = &state.app_handle;
795                            // SAFETY: AssertUnwindSafe is used here because:
796                            // - `body` is a Vec<u8> (unwind-safe by itself)
797                            // - `dispatch` is an Arc<Router> (unwind-safe)
798                            // - `app_handle_ref` borrows from Tauri state (unwind-safe)
799                            // - conduit's own locks use poison-recovery helpers (lock_or_recover)
800                            // - User-defined handler state may be left inconsistent after panic,
801                            //   but this is inherent to catch_unwind and documented as a limitation.
802                            let result =
803                                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
804                                    dispatch.call_with_context(&target, body, app_handle_ref)
805                                }));
806                            match result {
807                                Ok(Ok(bytes)) => {
808                                    responder.respond(make_response(
809                                        200,
810                                        "application/octet-stream",
811                                        bytes,
812                                    ));
813                                }
814                                Ok(Err(e)) => {
815                                    let status = error_to_status(&e);
816                                    responder
817                                        .respond(make_error_response(status, &sanitize_error(&e)));
818                                }
819                                Err(_) => {
820                                    // Panic caught by catch_unwind — keep as 500.
821                                    responder.respond(make_error_response(500, "handler panicked"));
822                                }
823                            }
824                        }
825                    }
826                    "drain" => match state.channel(&target) {
827                        Some(ch) => {
828                            let blob = ch.drain_all();
829                            responder.respond(make_response(200, "application/octet-stream", blob));
830                        }
831                        None => {
832                            responder.respond(make_error_response(
833                                404,
834                                &format!("unknown channel: {}", sanitize_name(&target)),
835                            ));
836                        }
837                    },
838                    _ => {
839                        responder.respond(make_error_response(
840                            404,
841                            "not found: expected /invoke/<cmd> or /drain/<channel>",
842                        ));
843                    }
844                }
845            })
846            // --- Register Tauri IPC commands ---
847            .invoke_handler(tauri::generate_handler![
848                conduit_bootstrap,
849                conduit_subscribe,
850            ])
851            // --- Plugin setup: create state, register commands ---
852            .setup(move |app, _api| {
853                let dispatch = Arc::new(Router::new());
854
855                // Register all old-style commands that were added via the builder.
856                for register_fn in commands {
857                    register_fn(&dispatch);
858                }
859
860                // Build the #[command] handler map, checking for collisions
861                // with Router commands.
862                let mut handler_map = HashMap::new();
863                for (name, handler) in handler_defs {
864                    if dispatch.has(&name) {
865                        #[cfg(debug_assertions)]
866                        eprintln!(
867                            "conduit: warning: handler '{name}' shadows a Router command \
868                             with the same name — the #[command] handler takes priority"
869                        );
870                    }
871                    handler_map.insert(name, handler);
872                }
873                let handlers = Arc::new(handler_map);
874
875                // Create named channels.
876                let mut channels = HashMap::new();
877                for (name, kind) in channel_defs {
878                    let buf = match kind {
879                        ChannelKind::Lossy(cap) => ChannelBuffer::Lossy(RingBuffer::new(cap)),
880                        ChannelKind::Reliable(max_bytes) => {
881                            ChannelBuffer::Reliable(Queue::new(max_bytes))
882                        }
883                    };
884                    channels.insert(name, Arc::new(buf));
885                }
886
887                // Generate the per-launch invoke key.
888                let invoke_key_bytes = generate_invoke_key_bytes();
889                let invoke_key = hex_encode(&invoke_key_bytes);
890
891                // Obtain the app handle for emitting events.
892                let app_handle = app.app_handle().clone();
893                let app_handle_arc = Arc::new(app_handle.clone());
894
895                let state = PluginState {
896                    dispatch,
897                    handlers,
898                    channels,
899                    app_handle,
900                    app_handle_arc,
901                    invoke_key,
902                    invoke_key_bytes,
903                };
904
905                app.manage(state);
906
907                Ok(())
908            })
909            .build()
910    }
911}
912
913impl Default for PluginBuilder {
914    fn default() -> Self {
915        Self::new()
916    }
917}
918
919// ---------------------------------------------------------------------------
920// Public init function
921// ---------------------------------------------------------------------------
922
923/// Create a new conduit plugin builder.
924///
925/// This is the main entry point for using the conduit Tauri plugin:
926///
927/// ```rust,ignore
928/// use tauri_conduit::command;
929///
930/// #[command]
931/// fn greet(name: String) -> String {
932///     format!("Hello, {name}!")
933/// }
934///
935/// #[command]
936/// async fn fetch_data(url: String) -> Result<Vec<u8>, String> {
937///     reqwest::get(&url).await.map_err(|e| e.to_string())?
938///         .bytes().await.map(|b| b.to_vec()).map_err(|e| e.to_string())
939/// }
940///
941/// tauri::Builder::default()
942///     .plugin(
943///         tauri_plugin_conduit::init()
944///             .handler("greet", handler!(greet))
945///             .handler("fetch_data", handler!(fetch_data))
946///             .channel("telemetry")
947///             .build()
948///     )
949///     .run(tauri::generate_context!())
950///     .unwrap();
951/// ```
952pub fn init() -> PluginBuilder {
953    PluginBuilder::new()
954}
955
956// ---------------------------------------------------------------------------
957// Helpers
958// ---------------------------------------------------------------------------
959
960/// Map a [`conduit_core::Error`] to the appropriate HTTP status code.
961fn error_to_status(e: &conduit_core::Error) -> u16 {
962    match e {
963        conduit_core::Error::UnknownCommand(_) => 404,
964        conduit_core::Error::UnknownChannel(_) => 404,
965        conduit_core::Error::AuthFailed => 403,
966        conduit_core::Error::DecodeFailed => 400,
967        conduit_core::Error::PayloadTooLarge(_) => 413,
968        conduit_core::Error::Handler(_) => 500,
969        conduit_core::Error::Serialize(_) => 500,
970        conduit_core::Error::ChannelFull => 500,
971    }
972}
973
/// Clamp a user-supplied name to at most 64 bytes and drop control
/// characters, guarding against log injection and oversized error messages.
///
/// The cut never splits a multi-byte character: when byte 64 falls inside a
/// character, the cut moves back to the start of that character. The result
/// is therefore always valid UTF-8 holding at most 64 bytes of text.
fn sanitize_name(name: &str) -> String {
    const MAX_BYTES: usize = 64;

    let cut = if name.len() > MAX_BYTES {
        // Scan downwards from the limit for the nearest char boundary.
        // Index 0 is always a boundary, so the search cannot come up empty.
        (0..=MAX_BYTES)
            .rev()
            .find(|&idx| name.is_char_boundary(idx))
            .unwrap_or(0)
    } else {
        name.len()
    };

    name[..cut].chars().filter(|ch| !ch.is_control()).collect()
}
992
993/// Format a [`conduit_core::Error`] for inclusion in HTTP error responses,
994/// sanitizing any embedded user-supplied names (command or channel names).
995fn sanitize_error(e: &conduit_core::Error) -> String {
996    match e {
997        conduit_core::Error::UnknownCommand(name) => {
998            format!("unknown command: {}", sanitize_name(name))
999        }
1000        conduit_core::Error::UnknownChannel(name) => {
1001            format!("unknown channel: {}", sanitize_name(name))
1002        }
1003        other => other.to_string(),
1004    }
1005}
1006
1007/// Percent-decode a URL path segment (e.g., `hello%20world` → `hello world`).
1008///
1009/// Returns `Cow::Borrowed` when no percent-encoding is present (the common
1010/// case), avoiding a heap allocation entirely.
1011fn percent_decode(input: &str) -> std::borrow::Cow<'_, str> {
1012    // Fast path: no percent-encoded characters — return the input as-is.
1013    if !input.as_bytes().contains(&b'%') {
1014        return std::borrow::Cow::Borrowed(input);
1015    }
1016    let mut result = Vec::with_capacity(input.len());
1017    let bytes = input.as_bytes();
1018    let mut i = 0;
1019    while i < bytes.len() {
1020        if bytes[i] == b'%' && i + 2 < bytes.len() {
1021            if let (Some(hi), Some(lo)) = (hex_val(bytes[i + 1]), hex_val(bytes[i + 2])) {
1022                result.push(hi << 4 | lo);
1023                i += 3;
1024                continue;
1025            }
1026        }
1027        result.push(bytes[i]);
1028        i += 1;
1029    }
1030    std::borrow::Cow::Owned(String::from_utf8_lossy(&result).into_owned())
1031}
1032
/// Decode one ASCII hex character (`0-9`, `a-f`, `A-F`) into its 4-bit value,
/// or `None` for any other byte.
///
/// Unlike [`hex_digit_ct`], this does NOT need to be constant-time — it is
/// used for URL percent-decoding, not security-critical key validation.
fn hex_val(b: u8) -> Option<u8> {
    // Bytes >= 0x80 map to Latin-1 chars, which `to_digit` rejects, so only
    // the ASCII hex digits are accepted. The digit is < 16 and fits in a u8.
    (b as char).to_digit(16).map(|digit| digit as u8)
}
1045
1046/// Generate 32 random bytes for the per-launch invoke key.
1047fn generate_invoke_key_bytes() -> [u8; 32] {
1048    let mut bytes = [0u8; 32];
1049    getrandom::fill(&mut bytes).expect("conduit: failed to generate invoke key");
1050    bytes
1051}
1052
/// Encode a byte slice as lowercase hexadecimal, two characters per byte.
///
/// The output buffer is sized once up front; no per-byte allocation occurs.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(bytes.len() * 2);
    out.extend(bytes.iter().flat_map(|&byte| {
        // High nibble first, then low nibble.
        let high = DIGITS[usize::from(byte >> 4)];
        let low = DIGITS[usize::from(byte & 0x0f)];
        [char::from(high), char::from(low)]
    }));
    out
}
1063
/// Decode a hex string into bytes, returning `None` when the input has odd
/// length or contains a non-hex character.
///
/// This is the non-constant-time version used for non-security paths.
/// For invoke key validation, see [`hex_digit_ct`] and the constant-time
/// path in [`PluginState::validate_invoke_key`].
#[cfg(test)]
fn hex_decode(hex: &str) -> Option<Vec<u8>> {
    let raw = hex.as_bytes();
    if raw.len() % 2 == 1 {
        return None;
    }
    // Each pair of characters yields one byte; the first invalid digit
    // aborts the whole decode.
    raw.chunks_exact(2)
        .map(|pair| {
            let hi = hex_digit(pair[0])?;
            let lo = hex_digit(pair[1])?;
            Some((hi << 4) | lo)
        })
        .collect()
}
1082
/// Decode one ASCII hex character (`0-9`, `a-f`, `A-F`) into its 4-bit value,
/// or `None` for any other byte. Test-only helper.
#[cfg(test)]
fn hex_digit(b: u8) -> Option<u8> {
    // Pick the amount to subtract so the first character of each range lands
    // on its numeric value (0 for '0', 10 for 'a' and 'A').
    let base = match b {
        b'0'..=b'9' => b'0',
        b'a'..=b'f' => b'a' - 10,
        b'A'..=b'F' => b'A' - 10,
        _ => return None,
    };
    Some(b - base)
}
1093
1094/// Validate an invoke key candidate using constant-time operations.
1095///
1096/// The length check (must be exactly 64 hex chars) is not constant-time
1097/// because the expected length is public knowledge. The hex decode and
1098/// byte comparison are fully constant-time: no early returns on invalid
1099/// characters, and the comparison always runs even if decode failed.
1100fn validate_invoke_key_ct(expected: &[u8; 32], candidate: &str) -> bool {
1101    let candidate_bytes = candidate.as_bytes();
1102
1103    // Length is not secret — always 64 hex chars for 32 bytes.
1104    if candidate_bytes.len() != 64 {
1105        return false;
1106    }
1107
1108    // Constant-time hex decode: always process all 32 byte pairs.
1109    let mut decoded = [0u8; 32];
1110    let mut all_valid = 1u8;
1111
1112    for i in 0..32 {
1113        let (hi_val, hi_ok) = hex_digit_ct(candidate_bytes[i * 2]);
1114        let (lo_val, lo_ok) = hex_digit_ct(candidate_bytes[i * 2 + 1]);
1115        decoded[i] = (hi_val << 4) | lo_val;
1116        all_valid &= hi_ok & lo_ok;
1117    }
1118
1119    // Always compare, even if some hex digits were invalid.
1120    let cmp_ok: bool = expected.ct_eq(&decoded).into();
1121
1122    // Combine with bitwise AND — no short-circuit.
1123    (all_valid == 1) & cmp_ok
1124}
1125
/// Constant-time hex digit decode for security-critical paths.
///
/// Returns `(value, valid)`: `valid` is `1` when the byte is an ASCII hex
/// character (`0-9`, `a-f`, `A-F`) and `0` otherwise; `value` is the decoded
/// nibble, or `0` when the byte is invalid.
///
/// Everything is done with wrapping arithmetic and bitwise operations on
/// `i16` — range membership is derived from sign bits, so no comparison
/// operators (which could compile to branches) are involved.
fn hex_digit_ct(b: u8) -> (u8, u8) {
    /// Branch-free range test: `1` when `0 <= offset < span`, else `0`.
    ///
    /// `!offset` is negative iff `offset >= 0`, and `offset - span` is
    /// negative iff `offset < span`; AND-ing them leaves the sign bit set
    /// only when both hold. The arithmetic shift smears that bit and `& 1`
    /// reduces it to a single flag.
    fn range_flag(offset: i16, span: i16) -> u8 {
        let sign = ((!offset) & offset.wrapping_sub(span)) >> 15;
        (sign & 1) as u8
    }

    // Widen to i16 so the subtractions below can go negative.
    let wide = i16::from(b);

    let from_zero = wide.wrapping_sub(0x30); // distance from '0'
    let from_lower = wide.wrapping_sub(0x61); // distance from 'a'
    let from_upper = wide.wrapping_sub(0x41); // distance from 'A'

    let is_digit = range_flag(from_zero, 10); // '0'..='9'
    let is_lower = range_flag(from_lower, 6); // 'a'..='f'
    let is_upper = range_flag(from_upper, 6); // 'A'..='F'

    // `flag.wrapping_neg()` turns 1 into 0xFF and leaves 0 as 0x00, giving a
    // select mask. The three ranges are disjoint, so at most one term is
    // non-zero and the sum is simply that term.
    let digit_part = (from_zero as u8 & 0x0f) & is_digit.wrapping_neg();
    let lower_part = (from_lower as u8).wrapping_add(10) & is_lower.wrapping_neg();
    let upper_part = (from_upper as u8).wrapping_add(10) & is_upper.wrapping_neg();

    let val = digit_part.wrapping_add(lower_part).wrapping_add(upper_part);
    let valid = is_digit | is_lower | is_upper;

    (val, valid)
}
1162
// Unit tests for the plugin's free-standing helpers (hex codec, invoke-key
// validation, response builders, percent-decoding, name sanitisation, error
// mapping) plus builder-level channel validation and `#[command]` state
// injection.
#[cfg(test)]
mod tests {
    use super::*;

    // -- hex encode/decode tests --

    // Despite the name, this only checks `hex_encode` against known vectors;
    // the actual encode/decode round trip is in `hex_roundtrip_32_bytes`.
    #[test]
    fn hex_encode_roundtrip() {
        assert_eq!(hex_encode(&[0xde, 0xad, 0xbe, 0xef]), "deadbeef");
        assert_eq!(hex_encode(&[]), "");
        assert_eq!(hex_encode(&[0x00, 0xff]), "00ff");
    }

    // Lowercase input, including the empty string, decodes to the expected bytes.
    #[test]
    fn hex_decode_valid() {
        assert_eq!(hex_decode("deadbeef"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
        assert_eq!(hex_decode(""), Some(vec![]));
        assert_eq!(hex_decode("00ff"), Some(vec![0x00, 0xff]));
    }

    // Uppercase and mixed-case digits are accepted.
    #[test]
    fn hex_decode_uppercase() {
        assert_eq!(hex_decode("DEADBEEF"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
        assert_eq!(hex_decode("DeAdBeEf"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
    }

    // An odd number of characters cannot form whole bytes and is rejected.
    #[test]
    fn hex_decode_odd_length() {
        assert_eq!(hex_decode("abc"), None);
        assert_eq!(hex_decode("a"), None);
    }

    // Any non-hex character makes the whole decode fail.
    #[test]
    fn hex_decode_invalid_chars() {
        assert_eq!(hex_decode("zz"), None);
        assert_eq!(hex_decode("gg"), None);
        assert_eq!(hex_decode("0x"), None);
    }

    // A freshly generated key survives encode -> decode and encodes to 64 chars.
    #[test]
    fn hex_roundtrip_32_bytes() {
        let original = generate_invoke_key_bytes();
        let encoded = hex_encode(&original);
        assert_eq!(encoded.len(), 64);
        let decoded = hex_decode(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    // -- constant-time hex tests --

    // Every character in the three valid ranges is flagged valid with the
    // correct nibble value.
    #[test]
    fn hex_digit_ct_valid_chars() {
        for b in b'0'..=b'9' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "digit {b} should be valid");
            assert_eq!(val, b - b'0');
        }
        for b in b'a'..=b'f' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "lower {b} should be valid");
            assert_eq!(val, b - b'a' + 10);
        }
        for b in b'A'..=b'F' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "upper {b} should be valid");
            assert_eq!(val, b - b'A' + 10);
        }
    }

    // The sample includes the bytes directly adjacent to each valid range
    // ('/' and ':' around the digits, '@' and 'G' around 'A'..='F', '`' and
    // 'g' around 'a'..='f') to catch off-by-one errors in the range masks.
    #[test]
    fn hex_digit_ct_invalid_chars() {
        for &b in &[b'g', b'z', b'G', b'Z', b' ', b'\0', b'/', b':', b'@', b'`'] {
            let (_val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 0, "char {b} should be invalid");
        }
    }

    // Exhaustive cross-check over all 256 byte values, using the plain
    // `hex_digit` as the reference. For invalid bytes only the validity flag
    // is compared, not the returned value.
    #[test]
    fn hex_digit_ct_matches_hex_digit() {
        for b in 0..=255u8 {
            let ct_result = hex_digit_ct(b);
            let std_result = hex_digit(b);
            match std_result {
                Some(v) => {
                    assert_eq!(ct_result.1, 1, "mismatch at {b}: ct says invalid");
                    assert_eq!(ct_result.0, v, "value mismatch at {b}");
                }
                None => {
                    assert_eq!(ct_result.1, 0, "mismatch at {b}: ct says valid");
                }
            }
        }
    }

    // -- make_response tests --

    // Status and body are passed through unchanged.
    #[test]
    fn make_response_200() {
        let resp = make_response(200, "application/octet-stream", b"hello".to_vec());
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.body(), b"hello");
    }

    #[test]
    fn make_response_404() {
        let resp = make_response(404, "text/plain", b"not found".to_vec());
        assert_eq!(resp.status(), 404);
        assert_eq!(resp.body(), b"not found");
    }

    // -- State<T> injection tests --

    // Fixture: a `#[command]` handler that takes a `tauri::State` parameter
    // alongside a regular named argument.
    #[command]
    fn with_state(state: tauri::State<'_, String>, name: String) -> String {
        format!("{}: {name}", state.as_str())
    }

    // Calling the generated handler with a context that is not a
    // `HandlerContext` (here: `()`) must yield a synchronous `Handler` error
    // rather than panicking.
    #[test]
    fn state_injection_wrong_context_returns_error() {
        use conduit_core::ConduitHandler;
        use conduit_derive::handler;

        let payload = serde_json::to_vec(&serde_json::json!({ "name": "test" })).unwrap();
        let wrong_ctx: Arc<dyn std::any::Any + Send + Sync> = Arc::new(());

        match handler!(with_state).call(payload, wrong_ctx) {
            conduit_core::HandlerResponse::Sync(Err(conduit_core::Error::Handler(msg))) => {
                assert!(
                    msg.contains("handler context must be HandlerContext"),
                    "unexpected error message: {msg}"
                );
            }
            _ => panic!("expected Sync(Err(Handler))"),
        }
    }

    #[test]
    fn original_state_function_preserved() {
        // The original function with_state is preserved and callable directly.
        // We can't call it without an actual Tauri State, but we can verify
        // the function exists and has the right signature by taking a reference.
        let _fn_ref: fn(tauri::State<'_, String>, String) -> String = with_state;
    }

    // -- validate_invoke_key tests --

    // The hex encoding of the expected key is accepted.
    #[test]
    fn validate_invoke_key_correct() {
        let key = [0xab_u8; 32];
        let hex = hex_encode(&key);
        assert!(validate_invoke_key_ct(&key, &hex));
    }

    // A well-formed candidate that encodes different bytes is rejected.
    #[test]
    fn validate_invoke_key_wrong_key() {
        let key = [0xab_u8; 32];
        let wrong = hex_encode(&[0x00_u8; 32]);
        assert!(!validate_invoke_key_ct(&key, &wrong));
    }

    // Anything other than exactly 64 characters is rejected, including the
    // off-by-one lengths 63 and 65.
    #[test]
    fn validate_invoke_key_wrong_length() {
        let key = [0xab_u8; 32];
        assert!(!validate_invoke_key_ct(&key, "abcdef"));
        assert!(!validate_invoke_key_ct(&key, ""));
        assert!(!validate_invoke_key_ct(&key, &"a".repeat(63)));
        assert!(!validate_invoke_key_ct(&key, &"a".repeat(65)));
    }

    #[test]
    fn validate_invoke_key_invalid_hex() {
        let key = [0xab_u8; 32];
        // 64 chars but invalid hex
        assert!(!validate_invoke_key_ct(&key, &"zz".repeat(32)));
        assert!(!validate_invoke_key_ct(&key, &"gg".repeat(32)));
    }

    #[test]
    fn validate_invoke_key_uppercase_accepted() {
        let key = [0xab_u8; 32];
        let hex = hex_encode(&key);
        // hex_digit_ct handles uppercase, so uppercase of a valid key should match
        assert!(validate_invoke_key_ct(&key, &hex.to_uppercase()));
    }

    // Same as the fixed-key case, but with a randomly generated key.
    #[test]
    fn validate_invoke_key_random_roundtrip() {
        let key = generate_invoke_key_bytes();
        let hex = hex_encode(&key);
        assert!(validate_invoke_key_ct(&key, &hex));
    }

    // -- make_error_response tests --

    // The body is JSON with the message under the "error" key.
    #[test]
    fn make_error_response_json_format() {
        let resp = make_error_response(500, "something failed");
        assert_eq!(resp.status(), 500);
        let body: serde_json::Value = serde_json::from_slice(resp.body()).unwrap();
        assert_eq!(body["error"], "something failed");
    }

    // Quotes and backslashes in the message must survive a JSON parse, i.e.
    // the body is properly escaped rather than string-concatenated.
    #[test]
    fn make_error_response_escapes_special_chars() {
        let resp = make_error_response(400, r#"bad "input" with \ slash"#);
        let body: serde_json::Value = serde_json::from_slice(resp.body()).unwrap();
        assert_eq!(body["error"], r#"bad "input" with \ slash"#);
    }

    // -- percent_decode tests --

    // Input without '%' comes back unchanged.
    #[test]
    fn percent_decode_no_encoding() {
        assert_eq!(percent_decode("hello"), "hello");
        assert_eq!(percent_decode("foo-bar_baz"), "foo-bar_baz");
    }

    // Single escapes decode, with hex digits accepted in either case.
    #[test]
    fn percent_decode_basic() {
        assert_eq!(percent_decode("hello%20world"), "hello world");
        assert_eq!(percent_decode("%2F"), "/");
        assert_eq!(percent_decode("%2f"), "/");
    }

    // Several escapes in one string, including back-to-back ones.
    #[test]
    fn percent_decode_multiple() {
        assert_eq!(percent_decode("a%20b%20c"), "a b c");
        assert_eq!(percent_decode("%41%42%43"), "ABC");
    }

    #[test]
    fn percent_decode_incomplete_sequence() {
        // Incomplete %XX at end — pass through unchanged.
        assert_eq!(percent_decode("hello%2"), "hello%2");
        assert_eq!(percent_decode("hello%"), "hello%");
    }

    #[test]
    fn percent_decode_invalid_hex() {
        // Invalid hex chars after % — pass through unchanged.
        assert_eq!(percent_decode("hello%GG"), "hello%GG");
        assert_eq!(percent_decode("%ZZ"), "%ZZ");
    }

    #[test]
    fn percent_decode_empty() {
        assert_eq!(percent_decode(""), "");
    }

    // -- sanitize_name tests --

    // Short, clean names are returned as-is.
    #[test]
    fn sanitize_name_short() {
        assert_eq!(sanitize_name("hello"), "hello");
    }

    // Names longer than 64 bytes are cut to 64.
    #[test]
    fn sanitize_name_truncates_long() {
        let long = "a".repeat(100);
        assert_eq!(sanitize_name(&long).len(), 64);
    }

    // NUL, LF and CR are removed, not replaced.
    #[test]
    fn sanitize_name_strips_control_chars() {
        assert_eq!(sanitize_name("hello\x00world"), "helloworld");
        assert_eq!(sanitize_name("foo\nbar\rbaz"), "foobarbaz");
    }

    #[test]
    fn sanitize_name_multibyte_utf8() {
        // "a" repeated 63 times + "é" (2 bytes: 0xC3 0xA9) = 65 bytes total.
        // Byte 64 is the second byte of "é", not a char boundary.
        // Must not panic — should truncate to the last valid boundary (63 'a's).
        let name = format!("{}{}", "a".repeat(63), "é");
        assert_eq!(name.len(), 65);
        let sanitized = sanitize_name(&name);
        assert_eq!(sanitized, "a".repeat(63));

        // 4-byte character crossing the 64-byte boundary.
        let name = format!("{}🦀", "a".repeat(62)); // 62 + 4 = 66 bytes
        assert_eq!(name.len(), 66);
        let sanitized = sanitize_name(&name);
        assert_eq!(sanitized, "a".repeat(62));

        // Exactly 64 bytes of ASCII — no truncation needed.
        let name = "a".repeat(64);
        assert_eq!(sanitize_name(&name), "a".repeat(64));
    }

    // -- error_to_status tests --

    // Pins the status code for each error variant.
    // NOTE(review): `Error::Serialize` is mapped by `error_to_status` but is
    // not asserted here.
    #[test]
    fn error_to_status_mapping() {
        use conduit_core::Error;
        assert_eq!(error_to_status(&Error::UnknownCommand("x".into())), 404);
        assert_eq!(error_to_status(&Error::UnknownChannel("x".into())), 404);
        assert_eq!(error_to_status(&Error::AuthFailed), 403);
        assert_eq!(error_to_status(&Error::DecodeFailed), 400);
        assert_eq!(error_to_status(&Error::PayloadTooLarge(999)), 413);
        assert_eq!(error_to_status(&Error::Handler("x".into())), 500);
        assert_eq!(error_to_status(&Error::ChannelFull), 500);
    }

    // -- channel validation tests --

    // Names made of letters, digits, '-' and '_' must not panic.
    #[test]
    fn validate_channel_name_valid() {
        validate_channel_name("telemetry");
        validate_channel_name("my-channel");
        validate_channel_name("my_channel");
        validate_channel_name("Channel123");
        validate_channel_name("a");
    }

    // The remaining cases assert the panic message via `should_panic`.
    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_empty() {
        validate_channel_name("");
    }

    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_spaces() {
        validate_channel_name("my channel");
    }

    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_special_chars() {
        validate_channel_name("my.channel");
    }

    // Registering the same channel name twice panics at builder time.
    #[test]
    #[should_panic(expected = "duplicate channel name")]
    fn duplicate_channel_panics() {
        PluginBuilder::new()
            .channel("telemetry")
            .channel("telemetry");
    }

    // The duplicate check applies across channel kinds: `channel` followed by
    // `channel_ordered` with the same name also panics.
    #[test]
    #[should_panic(expected = "duplicate channel name")]
    fn duplicate_channel_different_kinds_panics() {
        PluginBuilder::new().channel("data").channel_ordered("data");
    }
}