#![forbid(unsafe_code)]
#![deny(missing_docs)]
pub use conduit_derive::command;
pub use conduit_derive::handler;
use std::collections::HashMap;
use std::sync::Arc;
use conduit_core::{
ChannelBuffer, ConduitHandler, Decode, Encode, HandlerResponse, Queue, RingBuffer, Router,
};
use futures_util::FutureExt;
use subtle::ConstantTimeEq;
use tauri::plugin::{Builder as TauriPluginBuilder, TauriPlugin};
use tauri::{AppHandle, Emitter, Manager, Runtime};
/// Builds a protocol response with the given status code, `Content-Type`
/// header and body.
///
/// Every response also carries `Access-Control-Allow-Origin: *`. If the
/// `http` builder rejects the inputs (an invalid status code, or a
/// `content_type` that is not a legal header value), a bare `500` with the
/// body `internal error` is returned instead of panicking.
fn make_response(status: u16, content_type: &str, body: Vec<u8>) -> http::Response<Vec<u8>> {
    let built = http::Response::builder()
        .status(status)
        .header("Content-Type", content_type)
        .header("Access-Control-Allow-Origin", "*")
        .body(body);
    match built {
        Ok(response) => response,
        // The fallback only uses constant, always-valid inputs, so the
        // `expect` states an invariant rather than a recoverable failure.
        Err(_) => http::Response::builder()
            .status(500)
            .body(b"internal error".to_vec())
            .expect("fallback response must not fail"),
    }
}
/// Builds a JSON error response of the form `{"error": "<message>"}`.
///
/// The message is serialized (and therefore escaped) with `sonic_rs`. If
/// serialization fails, the constant body `{"error":"internal error"}` is sent
/// instead. The result goes through `make_response` with
/// `Content-Type: application/json`.
fn make_error_response(status: u16, message: &str) -> http::Response<Vec<u8>> {
    // Borrowing envelope so the message is not copied before serialization.
    #[derive(serde::Serialize)]
    struct ErrorBody<'a> {
        error: &'a str,
    }
    let envelope = ErrorBody { error: message };
    let body = match conduit_core::sonic_rs::to_vec(&envelope) {
        Ok(json) => json,
        Err(_) => br#"{"error":"internal error"}"#.to_vec(),
    };
    make_response(status, "application/json", body)
}
/// Connection details returned to the frontend by the `conduit_bootstrap`
/// command.
///
/// Serialized with camelCase keys (`protocolVersion`, `protocolBase`,
/// `invokeKey`, `channels`). The hand-written `Debug` impl redacts
/// `invoke_key`, so this value is safe to log with `{:?}`.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BootstrapInfo {
    /// Wire protocol version; falls back to `1` when missing during
    /// deserialization.
    #[serde(default = "default_protocol_version")]
    pub protocol_version: u8,
    /// Base URL of the custom `conduit` URI scheme (e.g. `conduit://localhost`).
    pub protocol_base: String,
    /// Hex-encoded secret generated during plugin setup; every protocol
    /// request must send it in the `X-Conduit-Key` header.
    pub invoke_key: String,
    /// Names of the channels registered on the plugin.
    pub channels: Vec<String>,
}
/// Serde default for `BootstrapInfo::protocol_version`: protocol version `1`.
fn default_protocol_version() -> u8 {
    const INITIAL_PROTOCOL_VERSION: u8 = 1;
    INITIAL_PROTOCOL_VERSION
}
/// Manual `Debug` so the invoke key never ends up in logs.
impl std::fmt::Debug for BootstrapInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            protocol_version,
            protocol_base,
            channels,
            ..
        } = self;
        let mut out = f.debug_struct("BootstrapInfo");
        out.field("protocol_version", protocol_version);
        out.field("protocol_base", protocol_base);
        // The secret itself is deliberately replaced by a placeholder.
        out.field("invoke_key", &"[REDACTED]");
        out.field("channels", channels);
        out.finish()
    }
}
/// Shared plugin state, stored in Tauri's managed state during plugin setup.
///
/// Holds the command router, the registered `ConduitHandler`s, the named
/// channel buffers and the invoke key used to authenticate `conduit://`
/// requests. Obtain it with `app_handle.state::<PluginState<R>>()`, for
/// example to `push` data into a channel.
pub struct PluginState<R: Runtime> {
    // Router holding commands added via the builder's `command*` / `handler_raw` methods.
    dispatch: Arc<Router>,
    // Handlers added via `PluginBuilder::handler`; looked up before `dispatch`.
    handlers: Arc<HashMap<String, Arc<dyn ConduitHandler>>>,
    // Channel name -> buffer served by `/drain/<channel>` requests.
    channels: HashMap<String, Arc<ChannelBuffer>>,
    // Used to emit data-available events and as the router call context.
    app_handle: AppHandle<R>,
    // Same handle behind an `Arc`, so it can be handed to handlers as `Arc<dyn Any>`.
    app_handle_arc: Arc<AppHandle<R>>,
    // Hex form of `invoke_key_bytes`, returned by `conduit_bootstrap`.
    invoke_key: String,
    // Raw 32-byte secret compared against the `X-Conduit-Key` header.
    invoke_key_bytes: [u8; 32],
}
/// Manual `Debug`: lists channel names only and redacts the invoke key.
impl<R: Runtime> std::fmt::Debug for PluginState<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let channel_names: Vec<&String> = self.channels.keys().collect();
        let mut out = f.debug_struct("PluginState");
        out.field("channels", &channel_names);
        out.field("invoke_key", &"[REDACTED]");
        out.finish()
    }
}
impl<R: Runtime> PluginState<R> {
    /// Returns the buffer of the channel called `name`, if it was registered.
    pub fn channel(&self, name: &str) -> Option<&Arc<ChannelBuffer>> {
        self.channels.get(name)
    }

    /// Pushes `data` into the channel called `channel` and notifies the frontend.
    ///
    /// After a successful push, two events are emitted, both carrying the
    /// channel name as payload:
    /// - the global `conduit:data-available` event;
    /// - the per-channel `conduit:data-available:<channel>` event.
    ///
    /// Emission is best effort. A failed emit is reported on stderr in debug
    /// builds only and does not make this method fail.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if no channel called `channel`
    /// exists, or if the channel buffer rejects the push.
    pub fn push(&self, channel: &str, data: &[u8]) -> Result<(), String> {
        let buffer = self
            .channels
            .get(channel)
            .ok_or_else(|| format!("unknown channel: {channel}"))?;
        buffer.push(data).map(|_| ()).map_err(|e| e.to_string())?;
        // The data is already buffered and can still be drained, so a missed
        // notification is not treated as an error.
        if self
            .app_handle
            .emit("conduit:data-available", channel)
            .is_err()
        {
            #[cfg(debug_assertions)]
            eprintln!(
                "conduit: failed to emit global data-available event for channel '{channel}'"
            );
        }
        if self
            .app_handle
            .emit(&format!("conduit:data-available:{channel}"), channel)
            .is_err()
        {
            #[cfg(debug_assertions)]
            eprintln!(
                "conduit: failed to emit per-channel data-available event for channel '{channel}'"
            );
        }
        Ok(())
    }

    /// Returns the names of all registered channels, sorted ascending.
    ///
    /// `HashMap` iteration order is randomized per process. Sorting keeps the
    /// list (and the `channels` field reported at bootstrap) stable across runs.
    pub fn channel_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.channels.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Checks `candidate` (a hex string from a request header) against the
    /// invoke key using the constant-time comparison helper.
    fn validate_invoke_key(&self, candidate: &str) -> bool {
        validate_invoke_key_ct(&self.invoke_key_bytes, candidate)
    }
}
#[tauri::command]
fn conduit_bootstrap(
state: tauri::State<'_, PluginState<tauri::Wry>>,
) -> Result<BootstrapInfo, String> {
Ok(BootstrapInfo {
protocol_version: 1,
protocol_base: "conduit://localhost".to_string(),
invoke_key: state.invoke_key.clone(),
channels: state.channel_names(),
})
}
/// Tauri command that filters the requested channel names down to those
/// registered on the plugin. Request order and duplicates are preserved.
#[tauri::command]
fn conduit_subscribe(
    state: tauri::State<'_, PluginState<tauri::Wry>>,
    channels: Vec<String>,
) -> Result<Vec<String>, String> {
    let mut known = channels;
    known.retain(|name| state.channels.contains_key(name.as_str()));
    Ok(known)
}
/// Buffering strategy requested for a channel at build time.
enum ChannelKind {
    /// Backed by a `RingBuffer` created with this capacity (`channel`,
    /// `channel_with_capacity`).
    Lossy(usize),
    /// Backed by a `Queue` created with this `max_bytes` limit
    /// (`channel_ordered`, `channel_ordered_with_capacity`).
    Reliable(usize),
}
/// Deferred registration of one command into the `Router`; each closure is
/// consumed once during plugin setup.
type CommandRegistration = Box<dyn FnOnce(&Router) + Send>;
/// Builder for the conduit Tauri plugin.
///
/// Collects commands, handlers and channel definitions, then turns them into
/// a `TauriPlugin` with `build`. Nothing is installed until the plugin's setup
/// hook runs.
pub struct PluginBuilder {
    // Closures that register commands into the router during setup.
    commands: Vec<CommandRegistration>,
    // `(name, handler)` pairs added with `handler`; looked up before the router.
    handler_defs: Vec<(String, Arc<dyn ConduitHandler>)>,
    // `(name, kind)` pairs; names are validated and unique.
    channel_defs: Vec<(String, ChannelKind)>,
}
/// Manual `Debug`: the stored closures and handlers are not `Debug`, so only
/// counts are shown.
impl std::fmt::Debug for PluginBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let command_count = self.commands.len();
        let handler_count = self.handler_defs.len();
        let channel_count = self.channel_defs.len();
        f.debug_struct("PluginBuilder")
            .field("commands", &command_count)
            .field("handlers", &handler_count)
            .field("channel_defs_count", &channel_count)
            .finish()
    }
}
/// Asserts that `name` is a non-empty string of ASCII letters, digits, `_`
/// or `-`.
///
/// # Panics
///
/// Panics with an `invalid channel name` message for any other input.
fn validate_channel_name(name: &str) {
    let allowed = |b: u8| matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-');
    let well_formed = !name.is_empty() && name.bytes().all(allowed);
    assert!(
        well_formed,
        "conduit: invalid channel name '{}' — must match [a-zA-Z0-9_-]+",
        name
    );
}
/// Capacity used by `channel` and `channel_ordered` when none is given
/// (`64 * 1024`).
///
/// Ordered channels pass it to `Queue::new` as `max_bytes`. Lossy channels
/// pass it to `RingBuffer::new`; presumably that is also a byte count, to be
/// confirmed in conduit_core.
const DEFAULT_CHANNEL_CAPACITY: usize = 64 * 1024;
impl PluginBuilder {
fn assert_no_duplicate_channel(&self, name: &str) {
if self.channel_defs.iter().any(|(n, _)| n == name) {
panic!(
"conduit: duplicate channel name '{}' — each channel must have a unique name",
name
);
}
}
pub fn new() -> Self {
Self {
commands: Vec::new(),
handler_defs: Vec::new(),
channel_defs: Vec::new(),
}
}
pub fn command<F>(mut self, name: impl Into<String>, handler: F) -> Self
where
F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
{
let name = name.into();
self.commands.push(Box::new(move |table: &Router| {
table.register(name, handler);
}));
self
}
pub fn handler(mut self, name: impl Into<String>, handler: impl ConduitHandler) -> Self {
self.handler_defs.push((name.into(), Arc::new(handler)));
self
}
pub fn handler_raw<F>(mut self, name: impl Into<String>, handler: F) -> Self
where
F: Fn(Vec<u8>, &dyn std::any::Any) -> Result<Vec<u8>, conduit_core::Error>
+ Send
+ Sync
+ 'static,
{
let name = name.into();
self.commands.push(Box::new(move |table: &Router| {
table.register_with_context(name, handler);
}));
self
}
pub fn command_json<F, A, R>(mut self, name: impl Into<String>, handler: F) -> Self
where
F: Fn(A) -> R + Send + Sync + 'static,
A: serde::de::DeserializeOwned + 'static,
R: serde::Serialize + 'static,
{
let name = name.into();
self.commands.push(Box::new(move |table: &Router| {
table.register_json(name, handler);
}));
self
}
pub fn command_json_result<F, A, R, E>(mut self, name: impl Into<String>, handler: F) -> Self
where
F: Fn(A) -> Result<R, E> + Send + Sync + 'static,
A: serde::de::DeserializeOwned + 'static,
R: serde::Serialize + 'static,
E: std::fmt::Display + 'static,
{
let name = name.into();
self.commands.push(Box::new(move |table: &Router| {
table.register_json_result(name, handler);
}));
self
}
pub fn command_binary<F, A, Ret>(mut self, name: impl Into<String>, handler: F) -> Self
where
F: Fn(A) -> Ret + Send + Sync + 'static,
A: Decode + 'static,
Ret: Encode + 'static,
{
let name = name.into();
self.commands.push(Box::new(move |table: &Router| {
table.register_binary(name, handler);
}));
self
}
pub fn channel(mut self, name: impl Into<String>) -> Self {
let name = name.into();
validate_channel_name(&name);
self.assert_no_duplicate_channel(&name);
self.channel_defs
.push((name, ChannelKind::Lossy(DEFAULT_CHANNEL_CAPACITY)));
self
}
pub fn channel_with_capacity(mut self, name: impl Into<String>, capacity: usize) -> Self {
let name = name.into();
validate_channel_name(&name);
self.assert_no_duplicate_channel(&name);
self.channel_defs.push((name, ChannelKind::Lossy(capacity)));
self
}
pub fn channel_ordered(mut self, name: impl Into<String>) -> Self {
let name = name.into();
validate_channel_name(&name);
self.assert_no_duplicate_channel(&name);
self.channel_defs
.push((name, ChannelKind::Reliable(DEFAULT_CHANNEL_CAPACITY)));
self
}
pub fn channel_ordered_with_capacity(
mut self,
name: impl Into<String>,
max_bytes: usize,
) -> Self {
let name = name.into();
validate_channel_name(&name);
self.assert_no_duplicate_channel(&name);
self.channel_defs
.push((name, ChannelKind::Reliable(max_bytes)));
self
}
pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
let commands = self.commands;
let handler_defs = self.handler_defs;
let channel_defs = self.channel_defs;
TauriPluginBuilder::<R>::new("conduit")
.register_asynchronous_uri_scheme_protocol("conduit", move |ctx, request, responder| {
if request.method() == "OPTIONS" {
let resp = http::Response::builder()
.status(204)
.header("Access-Control-Allow-Origin", "*")
.header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
.header(
"Access-Control-Allow-Headers",
"Content-Type, X-Conduit-Key, X-Conduit-Webview",
)
.header("Access-Control-Max-Age", "86400")
.body(Vec::new())
.expect("preflight response must not fail");
responder.respond(resp);
return;
}
let state: tauri::State<'_, PluginState<R>> = ctx.app_handle().state();
let path = request.uri().path();
let segments: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
if segments.len() != 2 {
responder.respond(make_error_response(
404,
"not found: expected /invoke/<cmd> or /drain/<channel>",
));
return;
}
let key = match request.headers().get("X-Conduit-Key") {
Some(v) => match v.to_str() {
Ok(s) => s,
Err(_) => {
responder
.respond(make_error_response(401, "invalid invoke key header"));
return;
}
},
None => {
responder.respond(make_error_response(401, "missing invoke key"));
return;
}
};
if !state.validate_invoke_key(key) {
responder.respond(make_error_response(403, "invalid invoke key"));
return;
}
let action = segments[0];
let raw_target = segments[1];
let target = percent_decode(raw_target);
if target.contains('/') {
responder.respond(make_error_response(400, "invalid command name"));
return;
}
match action {
"invoke" => {
let body = request.body().to_vec();
if let Some(handler) = state.handlers.get(&*target) {
let handler = Arc::clone(handler);
let webview_label = request
.headers()
.get("X-Conduit-Webview")
.and_then(|v| v.to_str().ok())
.filter(|s| {
!s.is_empty()
&& s.len() <= 128
&& s.bytes().all(|b| {
b.is_ascii_alphanumeric() || b == b'_' || b == b'-'
})
})
.map(|s| s.to_string());
let app_handle_arc: Arc<dyn std::any::Any + Send + Sync> =
state.app_handle_arc.clone();
let handler_ctx = conduit_core::HandlerContext::new(
app_handle_arc,
webview_label,
);
let ctx_any: Arc<dyn std::any::Any + Send + Sync> =
Arc::new(handler_ctx);
let result =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
handler.call(body, ctx_any)
}));
match result {
Ok(HandlerResponse::Sync(Ok(bytes))) => {
responder.respond(make_response(
200,
"application/octet-stream",
bytes,
));
}
Ok(HandlerResponse::Sync(Err(e))) => {
let status = error_to_status(&e);
responder
.respond(make_error_response(status, &sanitize_error(&e)));
}
Ok(HandlerResponse::Async(future)) => {
tauri::async_runtime::spawn(async move {
let result = std::panic::AssertUnwindSafe(future)
.catch_unwind()
.await;
match result {
Ok(Ok(bytes)) => {
responder.respond(make_response(
200,
"application/octet-stream",
bytes,
));
}
Ok(Err(e)) => {
let status = error_to_status(&e);
responder.respond(make_error_response(
status,
&sanitize_error(&e),
));
}
Err(_) => {
responder.respond(make_error_response(
500,
"handler panicked",
));
}
}
});
}
Err(_) => {
responder.respond(make_error_response(500, "handler panicked"));
}
}
} else {
let dispatch = Arc::clone(&state.dispatch);
let app_handle_ref = &state.app_handle;
let result =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
dispatch.call_with_context(&target, body, app_handle_ref)
}));
match result {
Ok(Ok(bytes)) => {
responder.respond(make_response(
200,
"application/octet-stream",
bytes,
));
}
Ok(Err(e)) => {
let status = error_to_status(&e);
responder
.respond(make_error_response(status, &sanitize_error(&e)));
}
Err(_) => {
responder.respond(make_error_response(500, "handler panicked"));
}
}
}
}
"drain" => match state.channel(&target) {
Some(ch) => {
let blob = ch.drain_all();
responder.respond(make_response(200, "application/octet-stream", blob));
}
None => {
responder.respond(make_error_response(
404,
&format!("unknown channel: {}", sanitize_name(&target)),
));
}
},
_ => {
responder.respond(make_error_response(
404,
"not found: expected /invoke/<cmd> or /drain/<channel>",
));
}
}
})
.invoke_handler(tauri::generate_handler![
conduit_bootstrap,
conduit_subscribe,
])
.setup(move |app, _api| {
let dispatch = Arc::new(Router::new());
for register_fn in commands {
register_fn(&dispatch);
}
let mut handler_map = HashMap::new();
for (name, handler) in handler_defs {
if dispatch.has(&name) {
#[cfg(debug_assertions)]
eprintln!(
"conduit: warning: handler '{name}' shadows a Router command \
with the same name — the #[command] handler takes priority"
);
}
handler_map.insert(name, handler);
}
let handlers = Arc::new(handler_map);
let mut channels = HashMap::new();
for (name, kind) in channel_defs {
let buf = match kind {
ChannelKind::Lossy(cap) => ChannelBuffer::Lossy(RingBuffer::new(cap)),
ChannelKind::Reliable(max_bytes) => {
ChannelBuffer::Reliable(Queue::new(max_bytes))
}
};
channels.insert(name, Arc::new(buf));
}
let invoke_key_bytes = generate_invoke_key_bytes();
let invoke_key = hex_encode(&invoke_key_bytes);
let app_handle = app.app_handle().clone();
let app_handle_arc = Arc::new(app_handle.clone());
let state = PluginState {
dispatch,
handlers,
channels,
app_handle,
app_handle_arc,
invoke_key,
invoke_key_bytes,
};
app.manage(state);
Ok(())
})
.build()
}
}
impl Default for PluginBuilder {
fn default() -> Self {
Self::new()
}
}
/// Entry point for configuring the conduit plugin.
///
/// Returns an empty `PluginBuilder`. Chain `command*` / `handler` /
/// `channel*` calls on it and finish with `build()` to obtain the Tauri plugin.
pub fn init() -> PluginBuilder {
    PluginBuilder::default()
}
fn error_to_status(e: &conduit_core::Error) -> u16 {
match e {
conduit_core::Error::UnknownCommand(_) => 404,
conduit_core::Error::UnknownChannel(_) => 404,
conduit_core::Error::AuthFailed => 403,
conduit_core::Error::DecodeFailed => 400,
conduit_core::Error::PayloadTooLarge(_) => 413,
conduit_core::Error::Handler(_) => 500,
conduit_core::Error::Serialize(_) => 500,
conduit_core::Error::ChannelFull => 500,
}
}
/// Makes a client-supplied name safe to echo back in an error message.
///
/// The name is cut to at most 64 bytes, backing off to the previous UTF-8
/// character boundary so no character is split. All control characters are
/// then removed.
fn sanitize_name(name: &str) -> String {
    const MAX_BYTES: usize = 64;
    let cut = if name.len() <= MAX_BYTES {
        name.len()
    } else {
        // Index 0 is always a char boundary, so `find` cannot fail.
        (0..=MAX_BYTES)
            .rev()
            .find(|&idx| name.is_char_boundary(idx))
            .unwrap_or(0)
    };
    name[..cut].chars().filter(|ch| !ch.is_control()).collect()
}
fn sanitize_error(e: &conduit_core::Error) -> String {
match e {
conduit_core::Error::UnknownCommand(name) => {
format!("unknown command: {}", sanitize_name(name))
}
conduit_core::Error::UnknownChannel(name) => {
format!("unknown channel: {}", sanitize_name(name))
}
other => other.to_string(),
}
}
/// Decodes `%XX` escapes (either hex case) in a URL path segment.
///
/// Input without any `%` is returned borrowed. Malformed escapes (non-hex
/// digits, or fewer than two characters after `%`) are copied through
/// literally. Decoded bytes that are not valid UTF-8 become U+FFFD.
fn percent_decode(input: &str) -> std::borrow::Cow<'_, str> {
    use std::borrow::Cow;
    let raw = input.as_bytes();
    if !raw.contains(&b'%') {
        return Cow::Borrowed(input);
    }
    let mut decoded = Vec::with_capacity(raw.len());
    let mut pos = 0;
    while pos < raw.len() {
        // A full escape needs `%` plus two more bytes at `pos`.
        let escape = match raw.get(pos..pos + 3) {
            Some(&[b'%', hi, lo]) => hex_val(hi).zip(hex_val(lo)),
            _ => None,
        };
        match escape {
            Some((hi, lo)) => {
                decoded.push(hi << 4 | lo);
                pos += 3;
            }
            None => {
                decoded.push(raw[pos]);
                pos += 1;
            }
        }
    }
    Cow::Owned(String::from_utf8_lossy(&decoded).into_owned())
}
/// Returns the value of one ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None`
/// for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    // `to_digit` only accepts ASCII digits and letters; bytes >= 0x80 map to
    // non-ASCII chars and are rejected, matching a hand-written table.
    char::from(b)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
/// Draws a fresh 32-byte invoke key from the operating system's random source
/// via `getrandom`.
///
/// # Panics
///
/// Panics if no random bytes can be obtained. Without a secret key, the
/// plugin has no way to authenticate protocol requests.
fn generate_invoke_key_bytes() -> [u8; 32] {
    let mut key = [0u8; 32];
    getrandom::fill(&mut key[..]).expect("conduit: failed to generate invoke key");
    key
}
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte, high
/// nibble first.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    encoded.extend(
        bytes
            .iter()
            .flat_map(|&byte| [DIGITS[usize::from(byte >> 4)], DIGITS[usize::from(byte & 0x0f)]])
            .map(char::from),
    );
    encoded
}
/// Test-only inverse of `hex_encode`. Accepts either hex case; returns `None`
/// for odd-length input or any non-hex character.
#[cfg(test)]
fn hex_decode(hex: &str) -> Option<Vec<u8>> {
    let digits = hex.as_bytes();
    if digits.len() % 2 == 1 {
        return None;
    }
    // Collecting `Option<u8>` items short-circuits on the first invalid pair.
    digits
        .chunks_exact(2)
        .map(|pair| Some((hex_digit(pair[0])? << 4) | hex_digit(pair[1])?))
        .collect()
}
/// Test-only name for `hex_val`, used by `hex_decode` and as the reference
/// oracle when checking `hex_digit_ct`.
#[cfg(test)]
fn hex_digit(b: u8) -> Option<u8> {
    // Delegate rather than keep a second copy of the same table, so the
    // test oracle and the production decoder cannot drift apart.
    hex_val(b)
}
/// Returns `true` if `candidate` is the hex encoding (either case) of
/// `expected`.
///
/// Written to avoid data-dependent branches on the key material:
/// - The only early return is the length check (64 characters). This reveals
///   nothing secret, because every valid key has the same length.
/// - All 64 characters are decoded with the mask-based `hex_digit_ct`.
/// - Validity is accumulated without short-circuiting.
/// - The decoded bytes are compared with `subtle::ConstantTimeEq`.
fn validate_invoke_key_ct(expected: &[u8; 32], candidate: &str) -> bool {
    let candidate_bytes = candidate.as_bytes();
    // Public length check: a legitimate key is always exactly 64 hex chars.
    if candidate_bytes.len() != 64 {
        return false;
    }
    let mut decoded = [0u8; 32];
    // Stays 1 only if every character was a valid hex digit.
    let mut all_valid = 1u8;
    for i in 0..32 {
        let (hi_val, hi_ok) = hex_digit_ct(candidate_bytes[i * 2]);
        let (lo_val, lo_ok) = hex_digit_ct(candidate_bytes[i * 2 + 1]);
        decoded[i] = (hi_val << 4) | lo_val;
        // Bitwise AND keeps the loop free of validity-dependent branches.
        all_valid &= hi_ok & lo_ok;
    }
    let cmp_ok: bool = expected.ct_eq(&decoded).into();
    // Non-short-circuiting `&` so both results are always evaluated.
    (all_valid == 1) & cmp_ok
}
/// Branch-free hex digit decoder used by `validate_invoke_key_ct`.
///
/// Returns `(value, valid)`:
/// - `valid` is `1` for `0-9`, `a-f` and `A-F`, and `0` otherwise.
/// - `value` is the digit's value when valid, and `0` otherwise (all three
///   range masks are then zero).
///
/// Range-test trick: `b` is widened to `i16` and `x = b - lo` is computed.
/// The expression `!x & (x - n)` has its sign bit set exactly when
/// `0 <= x < n`. An arithmetic `>> 15` spreads that sign bit to `-1` or `0`,
/// and `& 1` reduces it to a `1`/`0` mask.
fn hex_digit_ct(b: u8) -> (u8, u8) {
    let b = b as i16;
    // Mask for '0'..='9' (0x30..=0x39).
    let d = b.wrapping_sub(0x30);
    let digit_mask = ((!d) & (d.wrapping_sub(10))) >> 15;
    let digit_mask = (digit_mask & 1) as u8;
    // Mask for 'a'..='f' (0x61..=0x66).
    let l = b.wrapping_sub(0x61);
    let lower_mask = ((!l) & (l.wrapping_sub(6))) >> 15;
    let lower_mask = (lower_mask & 1) as u8;
    // Mask for 'A'..='F' (0x41..=0x46).
    let u = b.wrapping_sub(0x41);
    let upper_mask = ((!u) & (u.wrapping_sub(6))) >> 15;
    let upper_mask = (upper_mask & 1) as u8;
    // At most one mask is 1, so at most one term contributes. `wrapping_neg`
    // turns a 0/1 mask into 0x00/0xFF for the AND.
    let val = ((d as u8 & 0x0f) & digit_mask.wrapping_neg())
        .wrapping_add((l as u8).wrapping_add(10) & lower_mask.wrapping_neg())
        .wrapping_add((u as u8).wrapping_add(10) & upper_mask.wrapping_neg());
    let valid = digit_mask | lower_mask | upper_mask;
    (val, valid)
}
// Unit tests for the crate-private helpers: hex coding, invoke-key checks,
// response builders, percent decoding, name sanitizing, error mapping, and
// channel-name validation.
#[cfg(test)]
mod tests {
    use super::*;

    // --- hex encoding / decoding ---

    #[test]
    fn hex_encode_roundtrip() {
        assert_eq!(hex_encode(&[0xde, 0xad, 0xbe, 0xef]), "deadbeef");
        assert_eq!(hex_encode(&[]), "");
        assert_eq!(hex_encode(&[0x00, 0xff]), "00ff");
    }

    #[test]
    fn hex_decode_valid() {
        assert_eq!(hex_decode("deadbeef"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
        assert_eq!(hex_decode(""), Some(vec![]));
        assert_eq!(hex_decode("00ff"), Some(vec![0x00, 0xff]));
    }

    #[test]
    fn hex_decode_uppercase() {
        assert_eq!(hex_decode("DEADBEEF"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
        assert_eq!(hex_decode("DeAdBeEf"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
    }

    #[test]
    fn hex_decode_odd_length() {
        assert_eq!(hex_decode("abc"), None);
        assert_eq!(hex_decode("a"), None);
    }

    #[test]
    fn hex_decode_invalid_chars() {
        assert_eq!(hex_decode("zz"), None);
        assert_eq!(hex_decode("gg"), None);
        assert_eq!(hex_decode("0x"), None);
    }

    #[test]
    fn hex_roundtrip_32_bytes() {
        let original = generate_invoke_key_bytes();
        let encoded = hex_encode(&original);
        assert_eq!(encoded.len(), 64);
        let decoded = hex_decode(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    // --- branch-free hex digit decoder ---

    #[test]
    fn hex_digit_ct_valid_chars() {
        for b in b'0'..=b'9' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "digit {b} should be valid");
            assert_eq!(val, b - b'0');
        }
        for b in b'a'..=b'f' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "lower {b} should be valid");
            assert_eq!(val, b - b'a' + 10);
        }
        for b in b'A'..=b'F' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "upper {b} should be valid");
            assert_eq!(val, b - b'A' + 10);
        }
    }

    // Characters adjacent to the valid ranges ('/', ':', '@', '`', 'G', 'g')
    // exercise the range boundaries.
    #[test]
    fn hex_digit_ct_invalid_chars() {
        for &b in &[b'g', b'z', b'G', b'Z', b' ', b'\0', b'/', b':', b'@', b'`'] {
            let (_val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 0, "char {b} should be invalid");
        }
    }

    // Exhaustive comparison against the table-based decoder for all 256 bytes.
    #[test]
    fn hex_digit_ct_matches_hex_digit() {
        for b in 0..=255u8 {
            let ct_result = hex_digit_ct(b);
            let std_result = hex_digit(b);
            match std_result {
                Some(v) => {
                    assert_eq!(ct_result.1, 1, "mismatch at {b}: ct says invalid");
                    assert_eq!(ct_result.0, v, "value mismatch at {b}");
                }
                None => {
                    assert_eq!(ct_result.1, 0, "mismatch at {b}: ct says valid");
                }
            }
        }
    }

    // --- response builders ---

    #[test]
    fn make_response_200() {
        let resp = make_response(200, "application/octet-stream", b"hello".to_vec());
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.body(), b"hello");
    }

    #[test]
    fn make_response_404() {
        let resp = make_response(404, "text/plain", b"not found".to_vec());
        assert_eq!(resp.status(), 404);
        assert_eq!(resp.body(), b"not found");
    }

    // --- #[command] state injection ---

    // Fixture: a command that injects `tauri::State<String>`.
    #[command]
    fn with_state(state: tauri::State<'_, String>, name: String) -> String {
        format!("{}: {name}", state.as_str())
    }

    // Calling the generated handler with a context that is not a
    // `HandlerContext` must yield an error rather than panic.
    #[test]
    fn state_injection_wrong_context_returns_error() {
        use conduit_core::ConduitHandler;
        use conduit_derive::handler;
        let payload = serde_json::to_vec(&serde_json::json!({ "name": "test" })).unwrap();
        let wrong_ctx: Arc<dyn std::any::Any + Send + Sync> = Arc::new(());
        match handler!(with_state).call(payload, wrong_ctx) {
            conduit_core::HandlerResponse::Sync(Err(conduit_core::Error::Handler(msg))) => {
                assert!(
                    msg.contains("handler context must be HandlerContext"),
                    "unexpected error message: {msg}"
                );
            }
            _ => panic!("expected Sync(Err(Handler))"),
        }
    }

    // Compile-time check: the attribute keeps the original function callable
    // with its original signature.
    #[test]
    fn original_state_function_preserved() {
        let _fn_ref: fn(tauri::State<'_, String>, String) -> String = with_state;
    }

    // --- invoke key validation ---

    #[test]
    fn validate_invoke_key_correct() {
        let key = [0xab_u8; 32];
        let hex = hex_encode(&key);
        assert!(validate_invoke_key_ct(&key, &hex));
    }

    #[test]
    fn validate_invoke_key_wrong_key() {
        let key = [0xab_u8; 32];
        let wrong = hex_encode(&[0x00_u8; 32]);
        assert!(!validate_invoke_key_ct(&key, &wrong));
    }

    #[test]
    fn validate_invoke_key_wrong_length() {
        let key = [0xab_u8; 32];
        assert!(!validate_invoke_key_ct(&key, "abcdef"));
        assert!(!validate_invoke_key_ct(&key, ""));
        assert!(!validate_invoke_key_ct(&key, &"a".repeat(63)));
        assert!(!validate_invoke_key_ct(&key, &"a".repeat(65)));
    }

    #[test]
    fn validate_invoke_key_invalid_hex() {
        let key = [0xab_u8; 32];
        assert!(!validate_invoke_key_ct(&key, &"zz".repeat(32)));
        assert!(!validate_invoke_key_ct(&key, &"gg".repeat(32)));
    }

    #[test]
    fn validate_invoke_key_uppercase_accepted() {
        let key = [0xab_u8; 32];
        let hex = hex_encode(&key);
        assert!(validate_invoke_key_ct(&key, &hex.to_uppercase()));
    }

    #[test]
    fn validate_invoke_key_random_roundtrip() {
        let key = generate_invoke_key_bytes();
        let hex = hex_encode(&key);
        assert!(validate_invoke_key_ct(&key, &hex));
    }

    // --- JSON error responses ---

    #[test]
    fn make_error_response_json_format() {
        let resp = make_error_response(500, "something failed");
        assert_eq!(resp.status(), 500);
        let body: serde_json::Value = serde_json::from_slice(resp.body()).unwrap();
        assert_eq!(body["error"], "something failed");
    }

    #[test]
    fn make_error_response_escapes_special_chars() {
        let resp = make_error_response(400, r#"bad "input" with \ slash"#);
        let body: serde_json::Value = serde_json::from_slice(resp.body()).unwrap();
        assert_eq!(body["error"], r#"bad "input" with \ slash"#);
    }

    // --- percent decoding ---

    #[test]
    fn percent_decode_no_encoding() {
        assert_eq!(percent_decode("hello"), "hello");
        assert_eq!(percent_decode("foo-bar_baz"), "foo-bar_baz");
    }

    #[test]
    fn percent_decode_basic() {
        assert_eq!(percent_decode("hello%20world"), "hello world");
        assert_eq!(percent_decode("%2F"), "/");
        assert_eq!(percent_decode("%2f"), "/");
    }

    #[test]
    fn percent_decode_multiple() {
        assert_eq!(percent_decode("a%20b%20c"), "a b c");
        assert_eq!(percent_decode("%41%42%43"), "ABC");
    }

    // Truncated escapes are copied through unchanged.
    #[test]
    fn percent_decode_incomplete_sequence() {
        assert_eq!(percent_decode("hello%2"), "hello%2");
        assert_eq!(percent_decode("hello%"), "hello%");
    }

    #[test]
    fn percent_decode_invalid_hex() {
        assert_eq!(percent_decode("hello%GG"), "hello%GG");
        assert_eq!(percent_decode("%ZZ"), "%ZZ");
    }

    #[test]
    fn percent_decode_empty() {
        assert_eq!(percent_decode(""), "");
    }

    // --- name sanitizing ---

    #[test]
    fn sanitize_name_short() {
        assert_eq!(sanitize_name("hello"), "hello");
    }

    #[test]
    fn sanitize_name_truncates_long() {
        let long = "a".repeat(100);
        assert_eq!(sanitize_name(&long).len(), 64);
    }

    #[test]
    fn sanitize_name_strips_control_chars() {
        assert_eq!(sanitize_name("hello\x00world"), "helloworld");
        assert_eq!(sanitize_name("foo\nbar\rbaz"), "foobarbaz");
    }

    // Truncation must back off to a char boundary instead of splitting a
    // multi-byte character at byte 64.
    #[test]
    fn sanitize_name_multibyte_utf8() {
        let name = format!("{}{}", "a".repeat(63), "é");
        assert_eq!(name.len(), 65);
        let sanitized = sanitize_name(&name);
        assert_eq!(sanitized, "a".repeat(63));
        // 62 ASCII bytes + a 4-byte crab: bytes 63 and 64 fall inside the crab.
        let name = format!("{}🦀", "a".repeat(62));
        assert_eq!(name.len(), 66);
        let sanitized = sanitize_name(&name);
        assert_eq!(sanitized, "a".repeat(62));
        let name = "a".repeat(64);
        assert_eq!(sanitize_name(&name), "a".repeat(64));
    }

    // --- error → HTTP status mapping ---

    #[test]
    fn error_to_status_mapping() {
        use conduit_core::Error;
        assert_eq!(error_to_status(&Error::UnknownCommand("x".into())), 404);
        assert_eq!(error_to_status(&Error::UnknownChannel("x".into())), 404);
        assert_eq!(error_to_status(&Error::AuthFailed), 403);
        assert_eq!(error_to_status(&Error::DecodeFailed), 400);
        assert_eq!(error_to_status(&Error::PayloadTooLarge(999)), 413);
        assert_eq!(error_to_status(&Error::Handler("x".into())), 500);
        assert_eq!(error_to_status(&Error::ChannelFull), 500);
    }

    // --- channel name validation and uniqueness ---

    #[test]
    fn validate_channel_name_valid() {
        validate_channel_name("telemetry");
        validate_channel_name("my-channel");
        validate_channel_name("my_channel");
        validate_channel_name("Channel123");
        validate_channel_name("a");
    }

    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_empty() {
        validate_channel_name("");
    }

    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_spaces() {
        validate_channel_name("my channel");
    }

    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_special_chars() {
        validate_channel_name("my.channel");
    }

    #[test]
    #[should_panic(expected = "duplicate channel name")]
    fn duplicate_channel_panics() {
        PluginBuilder::new()
            .channel("telemetry")
            .channel("telemetry");
    }

    // Uniqueness is enforced across channel kinds, not per kind.
    #[test]
    #[should_panic(expected = "duplicate channel name")]
    fn duplicate_channel_different_kinds_panics() {
        PluginBuilder::new().channel("data").channel_ordered("data");
    }
}