tauri_plugin_conduit/lib.rs
1#![forbid(unsafe_code)]
2#![deny(missing_docs)]
3//! # tauri-plugin-conduit
4//!
5//! Tauri v2 plugin for conduit — binary IPC over the `conduit://` custom
6//! protocol.
7//!
8//! Registers a `conduit://` custom protocol for zero-overhead in-process
9//! binary dispatch. Supports both sync and async handlers via
10//! [`ConduitHandler`](conduit_core::ConduitHandler). No network surface.
11//!
12//! ## Usage
13//!
14//! ```rust,ignore
15//! use tauri_conduit::{command, handler};
16//!
17//! #[command]
18//! fn greet(name: String) -> String {
19//! format!("Hello, {name}!")
20//! }
21//!
22//! #[command]
23//! async fn fetch_user(state: State<'_, Db>, id: u64) -> Result<User, String> {
24//! state.get_user(id).await.map_err(|e| e.to_string())
25//! }
26//!
27//! tauri::Builder::default()
28//! .plugin(
29//! tauri_plugin_conduit::init()
30//! .handler("greet", handler!(greet))
31//! .handler("fetch_user", handler!(fetch_user))
32//! .channel("telemetry")
33//! .build()
34//! )
35//! .run(tauri::generate_context!())
36//! .unwrap();
37//! ```
38
39/// Re-export the `#[command]` attribute macro from `conduit-derive`.
40///
41/// This is conduit's equivalent of `#[tauri::command]`. Use it for
42/// named-parameter handlers:
43///
44/// ```rust,ignore
45/// use tauri_conduit::{command, handler};
46///
47/// #[command]
48/// fn greet(name: String, greeting: String) -> String {
49/// format!("{greeting}, {name}!")
50/// }
51/// ```
52pub use conduit_derive::command;
53
54/// Re-export the `handler!()` macro from `conduit-derive`.
55///
56/// Resolves a `#[command]` function name to its conduit handler struct
57/// for registration:
58///
59/// ```rust,ignore
60/// tauri_plugin_conduit::init()
61/// .handler("greet", handler!(greet))
62/// .build()
63/// ```
64pub use conduit_derive::handler;
65
66use std::collections::HashMap;
67use std::sync::Arc;
68
69use conduit_core::{
70 ChannelBuffer, ConduitHandler, Decode, Encode, HandlerResponse, Queue, RingBuffer, Router,
71};
72use subtle::ConstantTimeEq;
73use tauri::plugin::{Builder as TauriPluginBuilder, TauriPlugin};
74use tauri::{AppHandle, Emitter, Manager, Runtime};
75
76// ---------------------------------------------------------------------------
77// Helper: safe HTTP response builder
78// ---------------------------------------------------------------------------
79
80/// Build an HTTP response, falling back to a minimal 500 if construction fails.
81fn make_response(status: u16, content_type: &str, body: Vec<u8>) -> http::Response<Vec<u8>> {
82 http::Response::builder()
83 .status(status)
84 .header("Content-Type", content_type)
85 .header("Access-Control-Allow-Origin", "*")
86 .body(body)
87 .unwrap_or_else(|_| {
88 http::Response::builder()
89 .status(500)
90 .body(b"internal error".to_vec())
91 .expect("fallback response must not fail")
92 })
93}
94
95/// Build a JSON error response: `{"error": "message"}`.
96///
97/// Uses `sonic_rs` for proper RFC 8259 escaping of all control characters,
98/// newlines, quotes, and backslashes — not just `\` and `"`.
99fn make_error_response(status: u16, message: &str) -> http::Response<Vec<u8>> {
100 #[derive(serde::Serialize)]
101 struct ErrorBody<'a> {
102 error: &'a str,
103 }
104 let body = conduit_core::sonic_rs::to_vec(&ErrorBody { error: message })
105 .unwrap_or_else(|_| br#"{"error":"internal error"}"#.to_vec());
106 make_response(status, "application/json", body)
107}
108
109// ---------------------------------------------------------------------------
110// BootstrapInfo — returned to JS via `conduit_bootstrap` command
111// ---------------------------------------------------------------------------
112
113/// Connection info returned to the frontend during bootstrap.
114#[derive(Clone, serde::Serialize, serde::Deserialize)]
115#[serde(rename_all = "camelCase")]
116pub struct BootstrapInfo {
117 /// Protocol version (currently `1`). Allows the TS client to verify
118 /// protocol compatibility.
119 #[serde(default = "default_protocol_version")]
120 pub protocol_version: u8,
121 /// Base URL for the custom protocol (e.g., `"conduit://localhost"`).
122 pub protocol_base: String,
123 /// Per-launch invoke key for custom protocol authentication (hex-encoded).
124 ///
125 /// **Security**: This key authenticates custom protocol requests. It is
126 /// generated fresh each launch from 32 bytes of OS randomness and validated
127 /// using constant-time comparison. The JS client includes it as the
128 /// `X-Conduit-Key` header on every `conduit://` request.
129 pub invoke_key: String,
130 /// Available channel names.
131 pub channels: Vec<String>,
132}
133
/// Serde default for `BootstrapInfo::protocol_version`.
fn default_protocol_version() -> u8 {
    const CURRENT_PROTOCOL_VERSION: u8 = 1;
    CURRENT_PROTOCOL_VERSION
}
137
138impl std::fmt::Debug for BootstrapInfo {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 f.debug_struct("BootstrapInfo")
141 .field("protocol_version", &self.protocol_version)
142 .field("protocol_base", &self.protocol_base)
143 .field("invoke_key", &"[REDACTED]")
144 .field("channels", &self.channels)
145 .finish()
146 }
147}
148
149// ---------------------------------------------------------------------------
150// PluginState — managed Tauri state
151// ---------------------------------------------------------------------------
152
153/// Shared state for the conduit Tauri plugin.
154///
155/// Holds the router, named streaming channels, the per-launch invoke key,
156/// and the app handle for emitting push notifications.
157pub struct PluginState<R: Runtime> {
158 dispatch: Arc<Router>,
159 /// `#[command]`-generated handlers (sync and async via [`ConduitHandler`]).
160 handlers: Arc<HashMap<String, Arc<dyn ConduitHandler>>>,
161 /// Named channels for server→client streaming (lossy or ordered).
162 channels: HashMap<String, Arc<ChannelBuffer>>,
163 /// Tauri app handle for emitting events to the frontend.
164 app_handle: AppHandle<R>,
165 /// Per-launch invoke key (hex-encoded, 64 hex chars = 32 bytes).
166 invoke_key: String,
167 /// Raw invoke key bytes for constant-time comparison.
168 invoke_key_bytes: [u8; 32],
169}
170
171impl<R: Runtime> std::fmt::Debug for PluginState<R> {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("PluginState")
174 .field("channels", &self.channels.keys().collect::<Vec<_>>())
175 .field("invoke_key", &"[REDACTED]")
176 .finish()
177 }
178}
179
180impl<R: Runtime> PluginState<R> {
181 /// Get a channel by name (for pushing data from Rust handlers).
182 pub fn channel(&self, name: &str) -> Option<&Arc<ChannelBuffer>> {
183 self.channels.get(name)
184 }
185
186 /// Push binary data to a named channel and notify JS listeners.
187 ///
188 /// After writing to the channel, emits both a global
189 /// `conduit:data-available` event (payload = channel name) and a
190 /// per-channel `conduit:data-available:{channel}` event. JS subscribers
191 /// can listen on either.
192 ///
193 /// For lossy channels, oldest frames are silently dropped when the buffer
194 /// is full. For reliable channels, returns an error if the buffer is full
195 /// (backpressure).
196 ///
197 /// Returns an error string if the named channel was not registered via
198 /// the builder or if a reliable channel is full.
199 pub fn push(&self, channel: &str, data: &[u8]) -> Result<(), String> {
200 let ch = self
201 .channels
202 .get(channel)
203 .ok_or_else(|| format!("unknown channel: {channel}"))?;
204 ch.push(data).map(|_| ()).map_err(|e| e.to_string())?;
205 // Emit global event (backward-compatible with old JS code).
206 if self
207 .app_handle
208 .emit("conduit:data-available", channel)
209 .is_err()
210 {
211 #[cfg(debug_assertions)]
212 eprintln!(
213 "conduit: failed to emit global data-available event for channel '{channel}'"
214 );
215 }
216 // Emit per-channel event.
217 if self
218 .app_handle
219 .emit(&format!("conduit:data-available:{channel}"), channel)
220 .is_err()
221 {
222 #[cfg(debug_assertions)]
223 eprintln!(
224 "conduit: failed to emit per-channel data-available event for channel '{channel}'"
225 );
226 }
227 Ok(())
228 }
229
230 /// Return the list of registered channel names.
231 pub fn channel_names(&self) -> Vec<String> {
232 self.channels.keys().cloned().collect()
233 }
234
235 /// Validate an invoke key candidate using constant-time operations.
236 fn validate_invoke_key(&self, candidate: &str) -> bool {
237 validate_invoke_key_ct(&self.invoke_key_bytes, candidate)
238 }
239}
240
241// ---------------------------------------------------------------------------
242// Tauri commands
243// ---------------------------------------------------------------------------
244
245/// Return bootstrap info so the JS client knows how to reach the conduit
246/// custom protocol.
247///
248/// May be called multiple times (e.g., after page reloads during development).
249/// The invoke key is generated once at plugin setup and remains constant for
250/// the lifetime of the app process. Repeated calls return the same key.
251#[tauri::command]
252fn conduit_bootstrap(
253 state: tauri::State<'_, PluginState<tauri::Wry>>,
254) -> Result<BootstrapInfo, String> {
255 Ok(BootstrapInfo {
256 protocol_version: 1,
257 protocol_base: "conduit://localhost".to_string(),
258 invoke_key: state.invoke_key.clone(),
259 channels: state.channel_names(),
260 })
261}
262
263/// Validate channel names and return those that exist.
264///
265/// This is a validation-only endpoint — no server-side subscription state is
266/// tracked. The JS client uses the returned list to know which channels are
267/// available. Actual data delivery happens via `conduit:data-available` events
268/// and `conduit://localhost/drain/<channel>` protocol requests.
269///
270/// Unknown channel names are silently filtered out — only channels that
271/// exist are returned.
272#[tauri::command]
273fn conduit_subscribe(
274 state: tauri::State<'_, PluginState<tauri::Wry>>,
275 channels: Vec<String>,
276) -> Result<Vec<String>, String> {
277 // Silently filter to only channels that exist.
278 let valid: Vec<String> = channels
279 .into_iter()
280 .filter(|c| state.channels.contains_key(c.as_str()))
281 .collect();
282 Ok(valid)
283}
284
285// ---------------------------------------------------------------------------
286// Channel kind (internal)
287// ---------------------------------------------------------------------------
288
/// Which buffer to create for a channel once the plugin is set up.
///
/// The builder records only the kind and size; the actual buffer is
/// constructed later, during plugin setup.
enum ChannelKind {
    /// Lossy ring buffer; the payload is its byte capacity.
    Lossy(usize),
    /// Reliable queue; the payload is its maximum byte limit.
    Reliable(usize),
}
296
297// ---------------------------------------------------------------------------
298// Plugin builder
299// ---------------------------------------------------------------------------
300
/// A deferred command registration closure.
///
/// Builder methods that target the [`Router`] box one of these instead of
/// registering immediately; each closure is run exactly once against the
/// router created during plugin setup.
type CommandRegistration = Box<dyn FnOnce(&Router) + Send>;
303
304/// Builder for the conduit Tauri v2 plugin.
305///
306/// Collects command registrations and configuration, then produces a
307/// [`TauriPlugin`] via [`build`](Self::build).
308pub struct PluginBuilder {
309 /// Deferred command registrations: (name, handler factory).
310 commands: Vec<CommandRegistration>,
311 /// `#[command]`-generated handlers (sync and async).
312 handler_defs: Vec<(String, Arc<dyn ConduitHandler>)>,
313 /// Named channels: (name, kind).
314 channel_defs: Vec<(String, ChannelKind)>,
315}
316
317impl std::fmt::Debug for PluginBuilder {
318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 f.debug_struct("PluginBuilder")
320 .field("commands", &self.commands.len())
321 .field("handlers", &self.handler_defs.len())
322 .field("channel_defs_count", &self.channel_defs.len())
323 .finish()
324 }
325}
326
/// Ensure a channel name is non-empty and made only of ASCII letters,
/// digits, `_`, and `-` (i.e. it matches `[a-zA-Z0-9_-]+`).
///
/// # Panics
///
/// Panics with a message naming the offending channel if the check fails.
fn validate_channel_name(name: &str) {
    let is_allowed =
        |b: u8| matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-');
    let well_formed = !name.is_empty() && name.bytes().all(is_allowed);
    if !well_formed {
        panic!(
            "conduit: invalid channel name '{}' — must match [a-zA-Z0-9_-]+",
            name
        );
    }
}
338
/// Capacity used by the channel constructors that take no explicit size:
/// 64 KB (65,536 bytes).
const DEFAULT_CHANNEL_CAPACITY: usize = 1 << 16;
341
342impl PluginBuilder {
343 /// Panic if a channel with the given name is already registered.
344 fn assert_no_duplicate_channel(&self, name: &str) {
345 if self.channel_defs.iter().any(|(n, _)| n == name) {
346 panic!(
347 "conduit: duplicate channel name '{}' — each channel must have a unique name",
348 name
349 );
350 }
351 }
352
353 /// Create a new, empty plugin builder.
354 pub fn new() -> Self {
355 Self {
356 commands: Vec::new(),
357 handler_defs: Vec::new(),
358 channel_defs: Vec::new(),
359 }
360 }
361
362 // -- Raw handlers -------------------------------------------------------
363
364 /// Register a raw command handler (`Vec<u8>` in, `Vec<u8>` out).
365 ///
366 /// Command names correspond to the path segment in the
367 /// `conduit://localhost/invoke/<cmd_name>` URL.
368 pub fn command<F>(mut self, name: impl Into<String>, handler: F) -> Self
369 where
370 F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
371 {
372 let name = name.into();
373 self.commands.push(Box::new(move |table: &Router| {
374 table.register(name, handler);
375 }));
376 self
377 }
378
379 // -- ConduitHandler-based (#[command]-generated, sync or async) ----------
380
381 /// Register a `#[tauri_conduit::command]`-generated handler.
382 ///
383 /// Works with both sync and async handlers. Sync handlers are dispatched
384 /// inline. Async handlers are spawned on the tokio runtime — truly async,
385 /// exactly like `#[tauri::command]`.
386 ///
387 /// ```rust,ignore
388 /// use tauri_conduit::{command, handler};
389 ///
390 /// #[command]
391 /// fn greet(name: String) -> String {
392 /// format!("Hello, {name}!")
393 /// }
394 ///
395 /// #[command]
396 /// async fn fetch_user(state: State<'_, Db>, id: u64) -> Result<User, String> {
397 /// state.get_user(id).await.map_err(|e| e.to_string())
398 /// }
399 ///
400 /// tauri_plugin_conduit::init()
401 /// .handler("greet", handler!(greet))
402 /// .handler("fetch_user", handler!(fetch_user))
403 /// .build()
404 /// ```
405 pub fn handler(mut self, name: impl Into<String>, handler: impl ConduitHandler) -> Self {
406 self.handler_defs.push((name.into(), Arc::new(handler)));
407 self
408 }
409
410 /// Register a raw closure handler (legacy API).
411 ///
412 /// Accepts the same closure signature as the pre-`ConduitHandler` `.handler()`:
413 /// `Fn(Vec<u8>, &dyn Any) -> Result<Vec<u8>, Error>`. This is a synchronous
414 /// handler dispatched via `Router::register_with_context`.
415 ///
416 /// Use this for backward compatibility when migrating from closure-based
417 /// registration. For new code, prefer [`handler`](Self::handler) with
418 /// `#[tauri_conduit::command]` + `handler!()`.
419 pub fn handler_raw<F>(mut self, name: impl Into<String>, handler: F) -> Self
420 where
421 F: Fn(Vec<u8>, &dyn std::any::Any) -> Result<Vec<u8>, conduit_core::Error>
422 + Send
423 + Sync
424 + 'static,
425 {
426 let name = name.into();
427 self.commands.push(Box::new(move |table: &Router| {
428 table.register_with_context(name, handler);
429 }));
430 self
431 }
432
433 // -- JSON handlers (Level 1) --------------------------------------------
434
435 /// Typed JSON handler. Deserializes the request payload as `A` and
436 /// serializes the response as `R`.
437 ///
438 /// Unlike Tauri's `#[tauri::command]`, this takes a single argument type
439 /// (not named parameters) and does not support async or State injection.
440 ///
441 /// ```rust,ignore
442 /// .command_json("greet", |name: String| format!("Hello, {name}!"))
443 /// ```
444 pub fn command_json<F, A, R>(mut self, name: impl Into<String>, handler: F) -> Self
445 where
446 F: Fn(A) -> R + Send + Sync + 'static,
447 A: serde::de::DeserializeOwned + 'static,
448 R: serde::Serialize + 'static,
449 {
450 let name = name.into();
451 self.commands.push(Box::new(move |table: &Router| {
452 table.register_json(name, handler);
453 }));
454 self
455 }
456
457 /// Typed JSON handler that returns `Result<R, E>`.
458 ///
459 /// Like [`command_json`](Self::command_json), but the handler returns
460 /// `Result<R, E>` where `E: Display`. On success, `R` is serialized to
461 /// JSON. On error, the error's `Display` text is returned to the caller.
462 ///
463 /// For Tauri-style named parameters with `Result` returns, prefer
464 /// [`handler`](Self::handler) with `#[tauri_conduit::command]` instead:
465 ///
466 /// ```rust,ignore
467 /// use tauri_conduit::command;
468 ///
469 /// #[command]
470 /// fn divide(a: f64, b: f64) -> Result<f64, String> {
471 /// if b == 0.0 { Err("division by zero".into()) }
472 /// else { Ok(a / b) }
473 /// }
474 ///
475 /// // Preferred:
476 /// .handler("divide", divide)
477 /// ```
478 pub fn command_json_result<F, A, R, E>(mut self, name: impl Into<String>, handler: F) -> Self
479 where
480 F: Fn(A) -> Result<R, E> + Send + Sync + 'static,
481 A: serde::de::DeserializeOwned + 'static,
482 R: serde::Serialize + 'static,
483 E: std::fmt::Display + 'static,
484 {
485 let name = name.into();
486 self.commands.push(Box::new(move |table: &Router| {
487 table.register_json_result(name, handler);
488 }));
489 self
490 }
491
492 // -- Binary handlers (Level 2) ------------------------------------------
493
494 /// Register a typed binary command handler.
495 ///
496 /// The request payload is decoded via the [`Decode`] trait and the response
497 /// is encoded via [`Encode`]. No JSON involved — raw bytes in, raw bytes
498 /// out.
499 ///
500 /// ```rust,ignore
501 /// .command_binary("process", |tick: MarketTick| tick)
502 /// ```
503 pub fn command_binary<F, A, Ret>(mut self, name: impl Into<String>, handler: F) -> Self
504 where
505 F: Fn(A) -> Ret + Send + Sync + 'static,
506 A: Decode + 'static,
507 Ret: Encode + 'static,
508 {
509 let name = name.into();
510 self.commands.push(Box::new(move |table: &Router| {
511 table.register_binary(name, handler);
512 }));
513 self
514 }
515
516 // -- Lossy channels (default) -------------------------------------------
517
518 /// Register a lossy channel with the default capacity (64 KB).
519 ///
520 /// Oldest frames are silently dropped when the buffer is full. Best for
521 /// telemetry, game state, and real-time data where freshness matters more
522 /// than completeness.
523 ///
524 /// # Panics
525 ///
526 /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
527 /// or duplicates an already-registered channel name.
528 pub fn channel(mut self, name: impl Into<String>) -> Self {
529 let name = name.into();
530 validate_channel_name(&name);
531 self.assert_no_duplicate_channel(&name);
532 self.channel_defs
533 .push((name, ChannelKind::Lossy(DEFAULT_CHANNEL_CAPACITY)));
534 self
535 }
536
537 /// Register a lossy channel with a custom byte capacity.
538 ///
539 /// # Panics
540 ///
541 /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
542 /// or duplicates an already-registered channel name.
543 pub fn channel_with_capacity(mut self, name: impl Into<String>, capacity: usize) -> Self {
544 let name = name.into();
545 validate_channel_name(&name);
546 self.assert_no_duplicate_channel(&name);
547 self.channel_defs.push((name, ChannelKind::Lossy(capacity)));
548 self
549 }
550
551 // -- Reliable channels (guaranteed delivery) ----------------------------
552
553 /// Register an ordered channel with the default capacity (64 KB).
554 ///
555 /// No frames are ever dropped. When the buffer is full,
556 /// [`PluginState::push`] returns an error (backpressure). Best for
557 /// transaction logs, control messages, and any data that must arrive
558 /// intact and in order.
559 ///
560 /// # Panics
561 ///
562 /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
563 /// or duplicates an already-registered channel name.
564 pub fn channel_ordered(mut self, name: impl Into<String>) -> Self {
565 let name = name.into();
566 validate_channel_name(&name);
567 self.assert_no_duplicate_channel(&name);
568 self.channel_defs
569 .push((name, ChannelKind::Reliable(DEFAULT_CHANNEL_CAPACITY)));
570 self
571 }
572
573 /// Register an ordered channel with a custom byte limit.
574 ///
575 /// A `max_bytes` of `0` means unbounded — the buffer grows without limit.
576 ///
577 /// # Panics
578 ///
579 /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
580 /// or duplicates an already-registered channel name.
581 pub fn channel_ordered_with_capacity(
582 mut self,
583 name: impl Into<String>,
584 max_bytes: usize,
585 ) -> Self {
586 let name = name.into();
587 validate_channel_name(&name);
588 self.assert_no_duplicate_channel(&name);
589 self.channel_defs
590 .push((name, ChannelKind::Reliable(max_bytes)));
591 self
592 }
593
594 // -- Build --------------------------------------------------------------
595
596 /// Build the Tauri v2 plugin.
597 ///
598 /// This consumes the builder and returns a [`TauriPlugin`] that can be
599 /// passed to `tauri::Builder::plugin`.
600 ///
601 /// # Dispatch model
602 ///
603 /// Commands are dispatched through a two-tier system:
604 ///
605 /// 1. **`#[command]` handlers** (registered via [`.handler()`](Self::handler))
606 /// are checked first. These support named parameters, `State<T>` injection,
607 /// `Result` returns, and async — full parity with `#[tauri::command]`.
608 ///
609 /// 2. **Raw Router handlers** (registered via [`.command()`](Self::command),
610 /// [`.command_json()`](Self::command_json), [`.command_binary()`](Self::command_binary))
611 /// are the fallback. These are simpler `Vec<u8> -> Vec<u8>` functions
612 /// with no injection or async support.
613 ///
614 /// If a command name exists in both tiers, the `#[command]` handler takes
615 /// priority and a debug warning is printed.
616 pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
617 let commands = self.commands;
618 let handler_defs = self.handler_defs;
619 let channel_defs = self.channel_defs;
620
621 TauriPluginBuilder::<R>::new("conduit")
622 // --- Custom protocol: conduit://localhost/invoke/<cmd> ---
623 // Uses the asynchronous variant so async #[command] handlers
624 // are spawned on tokio (truly async, like #[tauri::command]).
625 .register_asynchronous_uri_scheme_protocol("conduit", move |ctx, request, responder| {
626 // Handle CORS preflight requests.
627 if request.method() == "OPTIONS" {
628 let resp = http::Response::builder()
629 .status(204)
630 .header("Access-Control-Allow-Origin", "*")
631 .header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
632 .header(
633 "Access-Control-Allow-Headers",
634 "Content-Type, X-Conduit-Key, X-Conduit-Webview",
635 )
636 .header("Access-Control-Max-Age", "86400")
637 .body(Vec::new())
638 .expect("preflight response must not fail");
639 responder.respond(resp);
640 return;
641 }
642
643 // Extract the managed PluginState from the app handle.
644 let state: tauri::State<'_, PluginState<R>> = ctx.app_handle().state();
645
646 let url = request.uri().to_string();
647
648 // Extract path from URL: conduit://localhost/{action}/{target}
649 // Use simple string splitting instead of full URL parsing —
650 // the format is fixed and under our control.
651 let path = url
652 .find("://")
653 .and_then(|i| url[i + 3..].find('/'))
654 .map(|i| {
655 let host_end = url.find("://").unwrap() + 3;
656 &url[host_end + i..]
657 })
658 .unwrap_or("/");
659 let segments: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
660
661 if segments.len() != 2 {
662 responder.respond(make_error_response(
663 404,
664 "not found: expected /invoke/<cmd> or /drain/<channel>",
665 ));
666 return;
667 }
668
669 // Validate the invoke key from the X-Conduit-Key header.
670 let key = match request.headers().get("X-Conduit-Key") {
671 Some(v) => match v.to_str() {
672 Ok(s) => s.to_string(),
673 Err(_) => {
674 responder
675 .respond(make_error_response(401, "invalid invoke key header"));
676 return;
677 }
678 },
679 None => {
680 responder.respond(make_error_response(401, "missing invoke key"));
681 return;
682 }
683 };
684
685 if !state.validate_invoke_key(&key) {
686 responder.respond(make_error_response(403, "invalid invoke key"));
687 return;
688 }
689
690 let action = segments[0];
691 let raw_target = segments[1];
692
693 // H6: Percent-decode the target and reject path traversal.
694 let target = percent_decode(raw_target);
695 if target.contains('/') {
696 responder.respond(make_error_response(400, "invalid command name"));
697 return;
698 }
699
700 match action {
701 "invoke" => {
702 let body = request.body().to_vec();
703
704 // 1) Check #[command]-generated handlers first (sync or async)
705 if let Some(handler) = state.handlers.get(&target) {
706 let handler = Arc::clone(handler);
707 let app_handle = ctx.app_handle().clone();
708 // Extract webview label from X-Conduit-Webview header (sent by JS client).
709 // NOTE: This header is client-provided and could be spoofed by JS
710 // running in the same webview. We validate the format to prevent
711 // injection attacks, but in a multi-webview app, code in one
712 // webview could impersonate another. This matches Tauri's own
713 // trust model where all JS in the webview is equally trusted.
714 let webview_label = request
715 .headers()
716 .get("X-Conduit-Webview")
717 .and_then(|v| v.to_str().ok())
718 .filter(|s| {
719 !s.is_empty()
720 && s.len() <= 128
721 && s.bytes().all(|b| {
722 b.is_ascii_alphanumeric() || b == b'_' || b == b'-'
723 })
724 })
725 .map(|s| s.to_string());
726 let handler_ctx = conduit_core::HandlerContext::new(
727 Arc::new(app_handle),
728 webview_label,
729 );
730 let ctx_any: Arc<dyn std::any::Any + Send + Sync> =
731 Arc::new(handler_ctx);
732
733 // SAFETY: AssertUnwindSafe is used here because:
734 // - `body` is a Vec<u8> (unwind-safe by itself)
735 // - `ctx_any` is an Arc (unwind-safe)
736 // - conduit's own locks use poison-recovery helpers (lock_or_recover)
737 // - User-defined handler state may be left inconsistent after panic,
738 // but this is inherent to catch_unwind and documented as a limitation.
739 let result =
740 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
741 handler.call(body, ctx_any)
742 }));
743
744 match result {
745 Ok(HandlerResponse::Sync(Ok(bytes))) => {
746 responder.respond(make_response(
747 200,
748 "application/octet-stream",
749 bytes,
750 ));
751 }
752 Ok(HandlerResponse::Sync(Err(e))) => {
753 let status = error_to_status(&e);
754 responder
755 .respond(make_error_response(status, &sanitize_error(&e)));
756 }
757 Ok(HandlerResponse::Async(future)) => {
758 // Truly async — spawned on tokio, just like #[tauri::command].
759 // Inner spawn provides panic isolation: if the future panics
760 // during execution, the JoinHandle catches it and we respond
761 // with a 500 instead of leaving the request hanging.
762 tauri::async_runtime::spawn(async move {
763 let handle = tauri::async_runtime::spawn(future);
764 match handle.await {
765 Ok(Ok(bytes)) => {
766 responder.respond(make_response(
767 200,
768 "application/octet-stream",
769 bytes,
770 ));
771 }
772 Ok(Err(e)) => {
773 let status = error_to_status(&e);
774 responder.respond(make_error_response(
775 status,
776 &sanitize_error(&e),
777 ));
778 }
779 Err(_) => {
780 // Panic during async handler execution
781 responder.respond(make_error_response(
782 500,
783 "handler panicked",
784 ));
785 }
786 }
787 });
788 }
789 Err(_) => {
790 // Panic caught by catch_unwind — keep as 500.
791 responder.respond(make_error_response(500, "handler panicked"));
792 }
793 }
794 } else {
795 // 2) Fall back to legacy sync Router
796 let dispatch = Arc::clone(&state.dispatch);
797 let app_handle = ctx.app_handle().clone();
798 // SAFETY: AssertUnwindSafe is used here because:
799 // - `body` is a Vec<u8> (unwind-safe by itself)
800 // - `dispatch` is an Arc<Router> (unwind-safe)
801 // - `app_handle` is a cloned AppHandle (unwind-safe)
802 // - conduit's own locks use poison-recovery helpers (lock_or_recover)
803 // - User-defined handler state may be left inconsistent after panic,
804 // but this is inherent to catch_unwind and documented as a limitation.
805 let result =
806 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
807 dispatch.call_with_context(&target, body, &app_handle)
808 }));
809 match result {
810 Ok(Ok(bytes)) => {
811 responder.respond(make_response(
812 200,
813 "application/octet-stream",
814 bytes,
815 ));
816 }
817 Ok(Err(e)) => {
818 let status = error_to_status(&e);
819 responder
820 .respond(make_error_response(status, &sanitize_error(&e)));
821 }
822 Err(_) => {
823 // Panic caught by catch_unwind — keep as 500.
824 responder.respond(make_error_response(500, "handler panicked"));
825 }
826 }
827 }
828 }
829 "drain" => match state.channel(&target) {
830 Some(ch) => {
831 let blob = ch.drain_all();
832 responder.respond(make_response(200, "application/octet-stream", blob));
833 }
834 None => {
835 responder.respond(make_error_response(
836 404,
837 &format!("unknown channel: {}", sanitize_name(&target)),
838 ));
839 }
840 },
841 _ => {
842 responder.respond(make_error_response(
843 404,
844 "not found: expected /invoke/<cmd> or /drain/<channel>",
845 ));
846 }
847 }
848 })
849 // --- Register Tauri IPC commands ---
850 .invoke_handler(tauri::generate_handler![
851 conduit_bootstrap,
852 conduit_subscribe,
853 ])
854 // --- Plugin setup: create state, register commands ---
855 .setup(move |app, _api| {
856 let dispatch = Arc::new(Router::new());
857
858 // Register all old-style commands that were added via the builder.
859 for register_fn in commands {
860 register_fn(&dispatch);
861 }
862
863 // Build the #[command] handler map, checking for collisions
864 // with Router commands.
865 let mut handler_map = HashMap::new();
866 for (name, handler) in handler_defs {
867 if dispatch.has(&name) {
868 #[cfg(debug_assertions)]
869 eprintln!(
870 "conduit: warning: handler '{name}' shadows a Router command \
871 with the same name — the #[command] handler takes priority"
872 );
873 }
874 handler_map.insert(name, handler);
875 }
876 let handlers = Arc::new(handler_map);
877
878 // Create named channels.
879 let mut channels = HashMap::new();
880 for (name, kind) in channel_defs {
881 let buf = match kind {
882 ChannelKind::Lossy(cap) => ChannelBuffer::Lossy(RingBuffer::new(cap)),
883 ChannelKind::Reliable(max_bytes) => {
884 ChannelBuffer::Reliable(Queue::new(max_bytes))
885 }
886 };
887 channels.insert(name, Arc::new(buf));
888 }
889
890 // Generate the per-launch invoke key.
891 let invoke_key_bytes = generate_invoke_key_bytes();
892 let invoke_key = hex_encode(&invoke_key_bytes);
893
894 // Obtain the app handle for emitting events.
895 let app_handle = app.app_handle().clone();
896
897 let state = PluginState {
898 dispatch,
899 handlers,
900 channels,
901 app_handle,
902 invoke_key,
903 invoke_key_bytes,
904 };
905
906 app.manage(state);
907
908 Ok(())
909 })
910 .build()
911 }
912}
913
914impl Default for PluginBuilder {
915 fn default() -> Self {
916 Self::new()
917 }
918}
919
920// ---------------------------------------------------------------------------
921// Public init function
922// ---------------------------------------------------------------------------
923
924/// Create a new conduit plugin builder.
925///
926/// This is the main entry point for using the conduit Tauri plugin:
927///
928/// ```rust,ignore
929/// use tauri_conduit::command;
930///
931/// #[command]
932/// fn greet(name: String) -> String {
933/// format!("Hello, {name}!")
934/// }
935///
936/// #[command]
937/// async fn fetch_data(url: String) -> Result<Vec<u8>, String> {
938/// reqwest::get(&url).await.map_err(|e| e.to_string())?
939/// .bytes().await.map(|b| b.to_vec()).map_err(|e| e.to_string())
940/// }
941///
942/// tauri::Builder::default()
943/// .plugin(
944/// tauri_plugin_conduit::init()
945/// .handler("greet", handler!(greet))
946/// .handler("fetch_data", handler!(fetch_data))
947/// .channel("telemetry")
948/// .build()
949/// )
950/// .run(tauri::generate_context!())
951/// .unwrap();
952/// ```
953pub fn init() -> PluginBuilder {
954 PluginBuilder::new()
955}
956
957// ---------------------------------------------------------------------------
958// Helpers
959// ---------------------------------------------------------------------------
960
961/// Map a [`conduit_core::Error`] to the appropriate HTTP status code.
962fn error_to_status(e: &conduit_core::Error) -> u16 {
963 match e {
964 conduit_core::Error::UnknownCommand(_) => 404,
965 conduit_core::Error::UnknownChannel(_) => 404,
966 conduit_core::Error::AuthFailed => 403,
967 conduit_core::Error::DecodeFailed => 400,
968 conduit_core::Error::PayloadTooLarge(_) => 413,
969 conduit_core::Error::Handler(_) => 500,
970 conduit_core::Error::Serialize(_) => 500,
971 conduit_core::Error::ChannelFull => 500,
972 }
973}
974
/// Clamp a user-supplied name to at most 64 bytes and drop control
/// characters, so it can be echoed in logs and error messages without
/// enabling log injection or unbounded output.
///
/// The 64-byte limit is applied first and never splits a UTF-8 character: a
/// character that would straddle the limit is dropped entirely. Control
/// characters are then removed from what remains, so the result is valid
/// UTF-8 of at most 64 bytes.
fn sanitize_name(name: &str) -> String {
    const MAX_BYTES: usize = 64;

    name.char_indices()
        // Keep only characters that end at or before the byte limit; this is
        // the same prefix as slicing at the last char boundary <= 64.
        .take_while(|&(start, ch)| start + ch.len_utf8() <= MAX_BYTES)
        .map(|(_, ch)| ch)
        .filter(|ch| !ch.is_control())
        .collect()
}
993
994/// Format a [`conduit_core::Error`] for inclusion in HTTP error responses,
995/// sanitizing any embedded user-supplied names (command or channel names).
996fn sanitize_error(e: &conduit_core::Error) -> String {
997 match e {
998 conduit_core::Error::UnknownCommand(name) => {
999 format!("unknown command: {}", sanitize_name(name))
1000 }
1001 conduit_core::Error::UnknownChannel(name) => {
1002 format!("unknown channel: {}", sanitize_name(name))
1003 }
1004 other => other.to_string(),
1005 }
1006}
1007
1008/// Percent-decode a URL path segment (e.g., `hello%20world` → `hello world`).
1009///
1010/// This is a minimal implementation — no new dependency needed. Only `%XX`
1011/// sequences with valid hex digits are decoded; all other bytes pass through
1012/// unchanged.
1013fn percent_decode(input: &str) -> String {
1014 let mut result = Vec::with_capacity(input.len());
1015 let bytes = input.as_bytes();
1016 let mut i = 0;
1017 while i < bytes.len() {
1018 if bytes[i] == b'%' && i + 2 < bytes.len() {
1019 if let (Some(hi), Some(lo)) = (hex_val(bytes[i + 1]), hex_val(bytes[i + 2])) {
1020 result.push(hi << 4 | lo);
1021 i += 3;
1022 continue;
1023 }
1024 }
1025 result.push(bytes[i]);
1026 i += 1;
1027 }
1028 String::from_utf8_lossy(&result).into_owned()
1029}
1030
/// Decode one ASCII hex character (`0-9`, `a-f`, `A-F`) into its value
/// `0..=15`, or `None` for any other byte.
///
/// Unlike [`hex_digit_ct`], this does NOT need to be constant-time — it is
/// used for URL percent-decoding, not security-critical key validation.
fn hex_val(b: u8) -> Option<u8> {
    // Radix-16 `to_digit` accepts exactly the ASCII hex digits in either
    // case; the result is below 16, so narrowing to `u8` is lossless.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
1043
1044/// Generate 32 random bytes for the per-launch invoke key.
1045fn generate_invoke_key_bytes() -> [u8; 32] {
1046 let mut bytes = [0u8; 32];
1047 getrandom::fill(&mut bytes).expect("conduit: failed to generate invoke key");
1048 bytes
1049}
1050
/// Encode `bytes` as a lowercase hexadecimal string, two characters per byte.
///
/// The output buffer is sized once up front, so no allocation happens per
/// byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(bytes.len() * 2);
    out.extend(
        bytes
            .iter()
            // High nibble first, then low nibble.
            .flat_map(|&byte| [DIGITS[usize::from(byte >> 4)], DIGITS[usize::from(byte & 0x0f)]])
            .map(char::from),
    );
    out
}
1061
/// Decode a hex string into bytes, returning `None` if the length is odd or
/// any character is not a hex digit.
///
/// Test-only and NOT constant-time. Invoke key validation must go through
/// [`validate_invoke_key_ct`] / [`hex_digit_ct`] instead.
#[cfg(test)]
fn hex_decode(hex: &str) -> Option<Vec<u8>> {
    let raw = hex.as_bytes();
    if raw.len() % 2 == 1 {
        return None;
    }

    // Each pair is (high nibble, low nibble); the first invalid digit makes
    // the whole decode yield `None`.
    raw.chunks_exact(2)
        .map(|pair| Some((hex_digit(pair[0])? << 4) | hex_digit(pair[1])?))
        .collect()
}
1080
/// Decode one ASCII hex character (`0-9`, `a-f`, `A-F`) into its value
/// `0..=15`, or `None` for any other byte. Test-only reference decoder.
#[cfg(test)]
fn hex_digit(b: u8) -> Option<u8> {
    // Fold `A-Z` onto `a-z` so one range check covers both letter cases;
    // all other bytes are left unchanged by the fold.
    let folded = b.to_ascii_lowercase();
    if folded.is_ascii_digit() {
        Some(folded - b'0')
    } else if (b'a'..=b'f').contains(&folded) {
        Some(folded - b'a' + 10)
    } else {
        None
    }
}
1091
/// Validate an invoke key candidate using constant-time operations.
///
/// The length check (must be exactly 64 hex chars) is not constant-time
/// because the expected length is public knowledge. The hex decode and
/// byte comparison are fully constant-time: no early returns on invalid
/// characters, and the comparison always runs even if decode failed.
///
/// `expected` is the raw 32-byte key and `candidate` is the hex string to
/// check. Returns `true` only when `candidate` is exactly 64 valid hex
/// characters (upper or lower case) that decode to `expected`.
fn validate_invoke_key_ct(expected: &[u8; 32], candidate: &str) -> bool {
    let candidate_bytes = candidate.as_bytes();

    // Length is not secret — always 64 hex chars for 32 bytes.
    if candidate_bytes.len() != 64 {
        return false;
    }

    // Constant-time hex decode: always process all 32 byte pairs.
    let mut decoded = [0u8; 32];
    // Stays 1 only while every digit seen so far was valid hex; it is folded
    // with `&=` instead of branching so an invalid digit does not end the loop.
    let mut all_valid = 1u8;

    for i in 0..32 {
        // Indexing cannot go out of bounds: the length was checked to be 64.
        let (hi_val, hi_ok) = hex_digit_ct(candidate_bytes[i * 2]);
        let (lo_val, lo_ok) = hex_digit_ct(candidate_bytes[i * 2 + 1]);
        decoded[i] = (hi_val << 4) | lo_val;
        all_valid &= hi_ok & lo_ok;
    }

    // Always compare, even if some hex digits were invalid.
    // NOTE(review): `ct_eq` is presumably `subtle::ConstantTimeEq::ct_eq`; the
    // import is outside this part of the file — confirm.
    let cmp_ok: bool = expected.ct_eq(&decoded).into();

    // Combine with bitwise AND — no short-circuit.
    (all_valid == 1) & cmp_ok
}
1123
/// Constant-time hex digit decode for security-critical paths.
///
/// Returns `(value, valid)`: `valid` is `1` when `b` is an ASCII hex digit
/// (`0-9`, `a-f`, `A-F`) and `0` otherwise; `value` is the digit's 4-bit
/// value, or `0` when invalid.
///
/// Everything is computed with subtraction, bitwise operators and sign-bit
/// extraction on `i16` — there are no comparison operators or branches that
/// could make the running time depend on the input byte.
fn hex_digit_ct(b: u8) -> (u8, u8) {
    // Yields 1 when `0 <= offset < width` and 0 otherwise, without comparing:
    // `!offset` has its sign bit set iff `offset >= 0`, and `offset - width`
    // has it set iff `offset < width`. ANDing the two and shifting the sign
    // bit down gives the range-membership bit.
    fn in_range_bit(offset: i16, width: i16) -> u8 {
        let sign = (!offset & offset.wrapping_sub(width)) >> 15;
        (sign & 1) as u8
    }

    // Widen to i16 so the subtractions below can go negative.
    let wide = i16::from(b);

    // Distance of `b` from the start of each accepted range.
    let from_zero = wide.wrapping_sub(0x30); // b - '0'
    let from_lower_a = wide.wrapping_sub(0x61); // b - 'a'
    let from_upper_a = wide.wrapping_sub(0x41); // b - 'A'

    let is_digit = in_range_bit(from_zero, 10);
    let is_lower = in_range_bit(from_lower_a, 6);
    let is_upper = in_range_bit(from_upper_a, 6);

    // `wrapping_neg` turns the 0/1 flag into a 0x00/0xFF byte mask, so each
    // candidate value survives only when its range matched. At most one
    // range can match, so summing the three parts selects the right value.
    let digit_part = (from_zero as u8 & 0x0f) & is_digit.wrapping_neg();
    let lower_part = (from_lower_a as u8).wrapping_add(10) & is_lower.wrapping_neg();
    let upper_part = (from_upper_a as u8).wrapping_add(10) & is_upper.wrapping_neg();

    (
        digit_part.wrapping_add(lower_part).wrapping_add(upper_part),
        is_digit | is_lower | is_upper,
    )
}
1160
#[cfg(test)]
mod tests {
    // Unit tests for the helpers in this file. Everything under test is
    // reached through `use super::*`, including items such as
    // `make_response`, `make_error_response`, `validate_channel_name` and
    // `PluginBuilder` that the parent module provides.
    use super::*;

    // -- hex encode/decode tests --

    // Despite the name, this checks only the encode direction against known
    // lowercase outputs; decoding is covered by the tests that follow.
    #[test]
    fn hex_encode_roundtrip() {
        assert_eq!(hex_encode(&[0xde, 0xad, 0xbe, 0xef]), "deadbeef");
        assert_eq!(hex_encode(&[]), "");
        assert_eq!(hex_encode(&[0x00, 0xff]), "00ff");
    }

    // Well-formed lowercase input, including the empty string.
    #[test]
    fn hex_decode_valid() {
        assert_eq!(hex_decode("deadbeef"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
        assert_eq!(hex_decode(""), Some(vec![]));
        assert_eq!(hex_decode("00ff"), Some(vec![0x00, 0xff]));
    }

    // Uppercase and mixed-case digits decode to the same bytes.
    #[test]
    fn hex_decode_uppercase() {
        assert_eq!(hex_decode("DEADBEEF"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
        assert_eq!(hex_decode("DeAdBeEf"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
    }

    // An odd number of characters cannot form whole bytes.
    #[test]
    fn hex_decode_odd_length() {
        assert_eq!(hex_decode("abc"), None);
        assert_eq!(hex_decode("a"), None);
    }

    // Any non-hex character rejects the whole input.
    #[test]
    fn hex_decode_invalid_chars() {
        assert_eq!(hex_decode("zz"), None);
        assert_eq!(hex_decode("gg"), None);
        assert_eq!(hex_decode("0x"), None);
    }

    // A freshly generated key survives encode → decode unchanged and encodes
    // to exactly 64 characters.
    #[test]
    fn hex_roundtrip_32_bytes() {
        let original = generate_invoke_key_bytes();
        let encoded = hex_encode(&original);
        assert_eq!(encoded.len(), 64);
        let decoded = hex_decode(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    // -- constant-time hex tests --

    // Every valid hex character is flagged valid and yields the right value.
    #[test]
    fn hex_digit_ct_valid_chars() {
        for b in b'0'..=b'9' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "digit {b} should be valid");
            assert_eq!(val, b - b'0');
        }
        for b in b'a'..=b'f' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "lower {b} should be valid");
            assert_eq!(val, b - b'a' + 10);
        }
        for b in b'A'..=b'F' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "upper {b} should be valid");
            assert_eq!(val, b - b'A' + 10);
        }
    }

    // Samples include the bytes immediately adjacent to each valid range
    // ('/', ':', '@', 'G', '`', 'g') to catch off-by-one mask errors.
    #[test]
    fn hex_digit_ct_invalid_chars() {
        for &b in &[b'g', b'z', b'G', b'Z', b' ', b'\0', b'/', b':', b'@', b'`'] {
            let (_val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 0, "char {b} should be invalid");
        }
    }

    // Exhaustive cross-check of the constant-time decoder against the plain
    // `hex_digit` reference for all 256 byte values. The value is compared
    // only for valid digits.
    #[test]
    fn hex_digit_ct_matches_hex_digit() {
        for b in 0..=255u8 {
            let ct_result = hex_digit_ct(b);
            let std_result = hex_digit(b);
            match std_result {
                Some(v) => {
                    assert_eq!(ct_result.1, 1, "mismatch at {b}: ct says invalid");
                    assert_eq!(ct_result.0, v, "value mismatch at {b}");
                }
                None => {
                    assert_eq!(ct_result.1, 0, "mismatch at {b}: ct says valid");
                }
            }
        }
    }

    // -- make_response tests --

    // Status and body are passed through; the content type argument is not
    // asserted here.
    #[test]
    fn make_response_200() {
        let resp = make_response(200, "application/octet-stream", b"hello".to_vec());
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.body(), b"hello");
    }

    #[test]
    fn make_response_404() {
        let resp = make_response(404, "text/plain", b"not found".to_vec());
        assert_eq!(resp.status(), 404);
        assert_eq!(resp.body(), b"not found");
    }

    // -- State<T> injection tests --

    // Fixture: a `#[command]` function that takes a `tauri::State` parameter
    // alongside one named argument.
    #[command]
    fn with_state(state: tauri::State<'_, String>, name: String) -> String {
        format!("{}: {name}", state.as_str())
    }

    // Calling the generated handler with a context of the wrong type (here
    // `()`) must produce `Sync(Err(Error::Handler(_)))` whose message names
    // the expected `HandlerContext`.
    #[test]
    fn state_injection_wrong_context_returns_error() {
        use conduit_core::ConduitHandler;
        use conduit_derive::handler;

        let payload = serde_json::to_vec(&serde_json::json!({ "name": "test" })).unwrap();
        let wrong_ctx: Arc<dyn std::any::Any + Send + Sync> = Arc::new(());

        match handler!(with_state).call(payload, wrong_ctx) {
            conduit_core::HandlerResponse::Sync(Err(conduit_core::Error::Handler(msg))) => {
                assert!(
                    msg.contains("handler context must be HandlerContext"),
                    "unexpected error message: {msg}"
                );
            }
            _ => panic!("expected Sync(Err(Handler))"),
        }
    }

    #[test]
    fn original_state_function_preserved() {
        // The original function with_state is preserved and callable directly.
        // We can't call it without an actual Tauri State, but we can verify
        // the function exists and has the right signature by taking a reference.
        let _fn_ref: fn(tauri::State<'_, String>, String) -> String = with_state;
    }

    // -- validate_invoke_key tests --

    // The hex encoding of the expected key is accepted.
    #[test]
    fn validate_invoke_key_correct() {
        let key = [0xab_u8; 32];
        let hex = hex_encode(&key);
        assert!(validate_invoke_key_ct(&key, &hex));
    }

    // A well-formed key of the right length but different bytes is rejected.
    #[test]
    fn validate_invoke_key_wrong_key() {
        let key = [0xab_u8; 32];
        let wrong = hex_encode(&[0x00_u8; 32]);
        assert!(!validate_invoke_key_ct(&key, &wrong));
    }

    // Anything other than exactly 64 characters is rejected, including the
    // off-by-one lengths 63 and 65.
    #[test]
    fn validate_invoke_key_wrong_length() {
        let key = [0xab_u8; 32];
        assert!(!validate_invoke_key_ct(&key, "abcdef"));
        assert!(!validate_invoke_key_ct(&key, ""));
        assert!(!validate_invoke_key_ct(&key, &"a".repeat(63)));
        assert!(!validate_invoke_key_ct(&key, &"a".repeat(65)));
    }

    #[test]
    fn validate_invoke_key_invalid_hex() {
        let key = [0xab_u8; 32];
        // 64 chars but invalid hex
        assert!(!validate_invoke_key_ct(&key, &"zz".repeat(32)));
        assert!(!validate_invoke_key_ct(&key, &"gg".repeat(32)));
    }

    #[test]
    fn validate_invoke_key_uppercase_accepted() {
        let key = [0xab_u8; 32];
        let hex = hex_encode(&key);
        // hex_digit_ct handles uppercase, so uppercase of a valid key should match
        assert!(validate_invoke_key_ct(&key, &hex.to_uppercase()));
    }

    // Same as the fixed-key case, but with a randomly generated key.
    #[test]
    fn validate_invoke_key_random_roundtrip() {
        let key = generate_invoke_key_bytes();
        let hex = hex_encode(&key);
        assert!(validate_invoke_key_ct(&key, &hex));
    }

    // -- make_error_response tests --

    // The body must parse as JSON with the message under the "error" key.
    #[test]
    fn make_error_response_json_format() {
        let resp = make_error_response(500, "something failed");
        assert_eq!(resp.status(), 500);
        let body: serde_json::Value = serde_json::from_slice(resp.body()).unwrap();
        assert_eq!(body["error"], "something failed");
    }

    // Quotes and backslashes in the message must round-trip through the JSON
    // body unchanged.
    #[test]
    fn make_error_response_escapes_special_chars() {
        let resp = make_error_response(400, r#"bad "input" with \ slash"#);
        let body: serde_json::Value = serde_json::from_slice(resp.body()).unwrap();
        assert_eq!(body["error"], r#"bad "input" with \ slash"#);
    }

    // -- percent_decode tests --

    // Input without any `%` is returned unchanged.
    #[test]
    fn percent_decode_no_encoding() {
        assert_eq!(percent_decode("hello"), "hello");
        assert_eq!(percent_decode("foo-bar_baz"), "foo-bar_baz");
    }

    // Single escapes, with hex digits in either case.
    #[test]
    fn percent_decode_basic() {
        assert_eq!(percent_decode("hello%20world"), "hello world");
        assert_eq!(percent_decode("%2F"), "/");
        assert_eq!(percent_decode("%2f"), "/");
    }

    // Several escapes in one input, both separated and back-to-back.
    #[test]
    fn percent_decode_multiple() {
        assert_eq!(percent_decode("a%20b%20c"), "a b c");
        assert_eq!(percent_decode("%41%42%43"), "ABC");
    }

    #[test]
    fn percent_decode_incomplete_sequence() {
        // Incomplete %XX at end — pass through unchanged.
        assert_eq!(percent_decode("hello%2"), "hello%2");
        assert_eq!(percent_decode("hello%"), "hello%");
    }

    #[test]
    fn percent_decode_invalid_hex() {
        // Invalid hex chars after % — pass through unchanged.
        assert_eq!(percent_decode("hello%GG"), "hello%GG");
        assert_eq!(percent_decode("%ZZ"), "%ZZ");
    }

    #[test]
    fn percent_decode_empty() {
        assert_eq!(percent_decode(""), "");
    }

    // -- sanitize_name tests --

    // Short, clean names pass through unchanged.
    #[test]
    fn sanitize_name_short() {
        assert_eq!(sanitize_name("hello"), "hello");
    }

    // Names longer than 64 bytes are cut to 64.
    #[test]
    fn sanitize_name_truncates_long() {
        let long = "a".repeat(100);
        assert_eq!(sanitize_name(&long).len(), 64);
    }

    // Control characters (NUL, LF, CR) are removed rather than replaced.
    #[test]
    fn sanitize_name_strips_control_chars() {
        assert_eq!(sanitize_name("hello\x00world"), "helloworld");
        assert_eq!(sanitize_name("foo\nbar\rbaz"), "foobarbaz");
    }

    #[test]
    fn sanitize_name_multibyte_utf8() {
        // "a" repeated 63 times + "é" (2 bytes: 0xC3 0xA9) = 65 bytes total.
        // Byte 64 is the second byte of "é", not a char boundary.
        // Must not panic — should truncate to the last valid boundary (63 'a's).
        let name = format!("{}{}", "a".repeat(63), "é");
        assert_eq!(name.len(), 65);
        let sanitized = sanitize_name(&name);
        assert_eq!(sanitized, "a".repeat(63));

        // 4-byte character crossing the 64-byte boundary.
        let name = format!("{}🦀", "a".repeat(62)); // 62 + 4 = 66 bytes
        assert_eq!(name.len(), 66);
        let sanitized = sanitize_name(&name);
        assert_eq!(sanitized, "a".repeat(62));

        // Exactly 64 bytes of ASCII — no truncation needed.
        let name = "a".repeat(64);
        assert_eq!(sanitize_name(&name), "a".repeat(64));
    }

    // -- error_to_status tests --

    // Pins the status code for each error variant.
    // NOTE(review): `Error::Serialize` is mapped by `error_to_status` but is
    // not exercised here.
    #[test]
    fn error_to_status_mapping() {
        use conduit_core::Error;
        assert_eq!(error_to_status(&Error::UnknownCommand("x".into())), 404);
        assert_eq!(error_to_status(&Error::UnknownChannel("x".into())), 404);
        assert_eq!(error_to_status(&Error::AuthFailed), 403);
        assert_eq!(error_to_status(&Error::DecodeFailed), 400);
        assert_eq!(error_to_status(&Error::PayloadTooLarge(999)), 413);
        assert_eq!(error_to_status(&Error::Handler("x".into())), 500);
        assert_eq!(error_to_status(&Error::ChannelFull), 500);
    }

    // -- channel validation tests --

    // Accepted names: there are no assertions, so the test passes as long as
    // none of these calls panics.
    #[test]
    fn validate_channel_name_valid() {
        validate_channel_name("telemetry");
        validate_channel_name("my-channel");
        validate_channel_name("my_channel");
        validate_channel_name("Channel123");
        validate_channel_name("a");
    }

    // The empty name is rejected with a panic.
    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_empty() {
        validate_channel_name("");
    }

    // A name containing a space is rejected with a panic.
    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_spaces() {
        validate_channel_name("my channel");
    }

    // A name containing '.' is rejected with a panic.
    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_special_chars() {
        validate_channel_name("my.channel");
    }

    // Registering the same channel name twice panics at registration time
    // (no `.build()` call is needed to trigger it).
    #[test]
    #[should_panic(expected = "duplicate channel name")]
    fn duplicate_channel_panics() {
        PluginBuilder::new()
            .channel("telemetry")
            .channel("telemetry");
    }

    // The duplicate check also fires when the second registration goes
    // through `channel_ordered` — presumably a different channel kind, per
    // the test name; confirm against the builder.
    #[test]
    #[should_panic(expected = "duplicate channel name")]
    fn duplicate_channel_different_kinds_panics() {
        PluginBuilder::new().channel("data").channel_ordered("data");
    }
}