Skip to main content

iii_sdk/
iii.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::{
4        Arc, Mutex, MutexGuard,
5        atomic::{AtomicBool, AtomicUsize, Ordering},
6    },
7    time::Duration,
8};
9
/// Poison-tolerant locking for [`Mutex`].
///
/// A mutex is poisoned when a thread panics while holding its guard. Callers of
/// `lock_or_recover` accept that the protected data may have been left in an
/// intermediate state and take the guard anyway instead of propagating a panic.
/// Only use it where the data stays valid after such a panic.
trait MutexExt<T> {
    fn lock_or_recover(&self) -> MutexGuard<'_, T>;
}

impl<T> MutexExt<T> for Mutex<T> {
    fn lock_or_recover(&self) -> MutexGuard<'_, T> {
        match self.lock() {
            Ok(guard) => guard,
            // The poison error still owns the guard; unwrap it.
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
21
22use futures_util::{SinkExt, StreamExt};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use tokio::{
26    sync::{mpsc, oneshot},
27    time::sleep,
28};
29use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
30use uuid::Uuid;
31
/// Version of this SDK crate, captured from `CARGO_PKG_VERSION` at compile time.
/// Used as the default `WorkerMetadata::version`.
const SDK_VERSION: &str = env!("CARGO_PKG_VERSION");
33
34use crate::{
35    channels::{ChannelReader, ChannelWriter, StreamChannelRef},
36    error::IIIError,
37    protocol::{
38        ErrorBody, HttpInvocationConfig, Message, RegisterFunctionMessage, RegisterServiceMessage,
39        RegisterTriggerInput, RegisterTriggerMessage, RegisterTriggerTypeMessage, TriggerAction,
40        TriggerRequest, UnregisterTriggerMessage, UnregisterTriggerTypeMessage,
41    },
42    triggers::{Trigger, TriggerConfig, TriggerHandler},
43    types::{Channel, RemoteFunctionData, RemoteFunctionHandler, RemoteTriggerTypeData},
44};
45
46use crate::telemetry;
47use crate::telemetry::types::OtelConfig;
48
/// Default timeout: 30 000 ms (30 s), going by the `_MS` suffix.
// NOTE(review): the consumer is not visible here — presumably the default
// timeout for invocations that don't specify one; confirm.
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
50
/// Worker information returned by `engine::workers::list`.
///
/// The engine defines the authoritative meaning of each field; the notes below
/// are hedged where they are inferred from field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    /// Worker identifier.
    pub id: String,
    /// Worker name, if reported (the Rust SDK's default is `"<hostname>:<pid>"`).
    pub name: Option<String>,
    /// Runtime identifier, if reported (the Rust SDK sends `"rust"` by default).
    pub runtime: Option<String>,
    /// SDK version, if reported.
    pub version: Option<String>,
    /// Operating-system description, if reported.
    pub os: Option<String>,
    /// Worker IP address, if known (presumably as seen by the engine).
    pub ip_address: Option<String>,
    /// Worker status; the set of values is engine-defined.
    pub status: String,
    /// Connection time — presumably milliseconds since the Unix epoch; confirm.
    pub connected_at_ms: u64,
    /// Number of functions registered by the worker.
    pub function_count: usize,
    /// Ids of the functions registered by the worker.
    pub functions: Vec<String>,
    /// Number of invocations currently active on the worker (presumably in flight).
    pub active_invocations: usize,
}
66
/// Function information returned by `engine::functions::list`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionInfo {
    /// Function id.
    pub function_id: String,
    /// Human-readable description, if one was provided (presumably at registration).
    pub description: Option<String>,
    /// JSON Schema of the request payload, if available.
    pub request_format: Option<Value>,
    /// JSON Schema of the response payload, if available.
    pub response_format: Option<Value>,
    /// Free-form metadata attached to the function, if any.
    pub metadata: Option<Value>,
}
76
/// Trigger information returned by `engine::triggers::list`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TriggerInfo {
    /// Trigger id.
    pub id: String,
    /// Id of the trigger type this trigger belongs to.
    pub trigger_type: String,
    /// Id of the function the trigger invokes (presumably).
    pub function_id: String,
    /// Trigger-type-specific configuration as raw JSON.
    pub config: Value,
    /// Free-form metadata attached to the trigger, if any.
    pub metadata: Option<Value>,
}
86
/// Trigger type information returned by `engine::trigger-types::list`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TriggerTypeInfo {
    /// Trigger type id.
    pub id: String,
    /// Human-readable description.
    pub description: String,
    /// JSON Schema of the trigger configuration; omitted when serialized if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trigger_request_format: Option<Value>,
    /// JSON Schema of the call request; omitted when serialized if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub call_request_format: Option<Value>,
}
97
/// Builder for registering a custom trigger type with optional format schemas.
///
/// Type parameters:
/// - `C` tracks the trigger registration type (set via `.trigger_request_format::<T>()`)
/// - `R` tracks the call request type (set via `.call_request_format::<T>()`)
///
/// Both default to `Value` (untyped) and change when the respective builder
/// method is called. This allows [`III::register_trigger_type`] to return a
/// [`TriggerTypeRef<C, R>`] with compile-time safety for both config and
/// function input types.
pub struct RegisterTriggerType<H, C = Value, R = Value> {
    /// Trigger type id.
    id: String,
    /// Human-readable description.
    description: String,
    /// Handler for the trigger type (the constructors require `H: TriggerHandler`).
    handler: H,
    /// JSON Schema of the trigger config; set by `trigger_request_format`.
    trigger_request_format: Option<Value>,
    /// JSON Schema of the call request; set by `call_request_format`.
    call_request_format: Option<Value>,
    /// Compile-time marker for `C` and `R`; no values of either are stored.
    _phantom: std::marker::PhantomData<(C, R)>,
}
116
117impl<H: TriggerHandler> RegisterTriggerType<H> {
118    pub fn new(id: impl Into<String>, description: impl Into<String>, handler: H) -> Self {
119        Self {
120            id: id.into(),
121            description: description.into(),
122            handler,
123            trigger_request_format: None,
124            call_request_format: None,
125            _phantom: std::marker::PhantomData,
126        }
127    }
128}
129
130impl<H: TriggerHandler, C, R> RegisterTriggerType<H, C, R> {
131    /// Set the trigger request format schema from a type.
132    /// Changes `C`, enabling compile-time validation on
133    /// [`TriggerTypeRef::register_trigger`].
134    pub fn trigger_request_format<T: schemars::JsonSchema + Serialize>(
135        self,
136    ) -> RegisterTriggerType<H, T, R> {
137        RegisterTriggerType {
138            id: self.id,
139            description: self.description,
140            handler: self.handler,
141            trigger_request_format: json_schema_for::<T>(),
142            call_request_format: self.call_request_format,
143            _phantom: std::marker::PhantomData,
144        }
145    }
146
147    /// Set the call request format schema from a type.
148    /// Changes `R`, enabling compile-time validation on
149    /// [`TriggerTypeRef::register_function`].
150    pub fn call_request_format<T: schemars::JsonSchema>(self) -> RegisterTriggerType<H, C, T> {
151        RegisterTriggerType {
152            id: self.id,
153            description: self.description,
154            handler: self.handler,
155            trigger_request_format: self.trigger_request_format,
156            call_request_format: json_schema_for::<T>(),
157            _phantom: std::marker::PhantomData,
158        }
159    }
160}
161
162/// Typed handle returned by [`III::register_trigger_type`].
163///
164/// Type parameters:
165/// - `C` — trigger registration type for [`register_trigger`](Self::register_trigger)
166/// - `R` — call request type for [`register_function`](Self::register_function)
167#[derive(Clone)]
168pub struct TriggerTypeRef<C = Value, R = Value> {
169    iii: III,
170    trigger_type_id: String,
171    _phantom: std::marker::PhantomData<(C, R)>,
172}
173
174impl<C: Serialize, R> TriggerTypeRef<C, R> {
175    /// Register a trigger with compile-time validated trigger config.
176    pub fn register_trigger(
177        &self,
178        function_id: impl Into<String>,
179        config: C,
180    ) -> Result<Trigger, IIIError> {
181        self.register_trigger_with_metadata(function_id, config, None)
182    }
183
184    /// Register a trigger with compile-time validated trigger config and optional metadata.
185    pub fn register_trigger_with_metadata(
186        &self,
187        function_id: impl Into<String>,
188        config: C,
189        metadata: Option<Value>,
190    ) -> Result<Trigger, IIIError> {
191        self.iii.register_trigger(RegisterTriggerInput {
192            trigger_type: self.trigger_type_id.clone(),
193            function_id: function_id.into(),
194            config: serde_json::to_value(config).map_err(|e| IIIError::Handler(e.to_string()))?,
195            metadata,
196        })
197    }
198}
199
200impl<C, R> TriggerTypeRef<C, R>
201where
202    R: serde::de::DeserializeOwned + schemars::JsonSchema + Send + 'static,
203{
204    /// Register a sync function whose input type must match
205    /// the call request format `R`.
206    pub fn register_function<O, E, F>(&self, id: impl Into<String>, f: F) -> FunctionRef
207    where
208        O: Serialize + schemars::JsonSchema + Send + 'static,
209        E: std::fmt::Display + Send + 'static,
210        F: Fn(R) -> Result<O, E> + Send + Sync + 'static,
211    {
212        self.iii.register_function(RegisterFunction::new(id, f))
213    }
214
215    /// Register an async function whose input type must match
216    /// the call request format `R`.
217    pub fn register_function_async<O, E, F, Fut>(&self, id: impl Into<String>, f: F) -> FunctionRef
218    where
219        O: Serialize + schemars::JsonSchema + Send + 'static,
220        E: std::fmt::Display + Send + 'static,
221        F: Fn(R) -> Fut + Send + Sync + 'static,
222        Fut: std::future::Future<Output = Result<O, E>> + Send + 'static,
223    {
224        self.iii
225            .register_function(RegisterFunction::new_async(id, f))
226    }
227}
228
/// Telemetry metadata provided by the SDK to the engine.
///
/// Every field is optional and is omitted from the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WorkerTelemetryMeta {
    /// Locale language tag. The default worker metadata fills this from the
    /// locale environment variables (e.g. `en_US` from `en_US.UTF-8`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Project name, if provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub project_name: Option<String>,
    /// Framework name, if provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub framework: Option<String>,
    /// Amplitude API key, if provided.
    // NOTE(review): presumably consumed by engine-side analytics; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub amplitude_api_key: Option<String>,
}
241
/// Worker metadata for auto-registration.
///
/// `WorkerMetadata::default()` auto-detects every field from the current
/// process. Use [`III::with_metadata`] or [`III::set_metadata`] to supply your own.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerMetadata {
    /// Runtime identifier (`"rust"` by default).
    pub runtime: String,
    /// SDK version (this crate's version by default).
    pub version: String,
    /// Worker name (`"<hostname>:<pid>"` by default).
    pub name: String,
    /// OS description (`"<os> <arch> (<family>)"` by default).
    pub os: String,
    /// Process id; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pid: Option<u32>,
    /// Telemetry metadata; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub telemetry: Option<WorkerTelemetryMeta>,
}
254
255impl Default for WorkerMetadata {
256    fn default() -> Self {
257        let hostname = hostname::get()
258            .map(|h| h.to_string_lossy().to_string())
259            .unwrap_or_else(|_| "unknown".to_string());
260        let pid = std::process::id();
261        let os_info = format!(
262            "{} {} ({})",
263            std::env::consts::OS,
264            std::env::consts::ARCH,
265            std::env::consts::FAMILY
266        );
267
268        let language = std::env::var("LANG")
269            .or_else(|_| std::env::var("LC_ALL"))
270            .ok()
271            .filter(|s| !s.is_empty())
272            .map(|s| s.split('.').next().unwrap_or(&s).to_string());
273
274        Self {
275            runtime: "rust".to_string(),
276            version: SDK_VERSION.to_string(),
277            name: format!("{}:{}", hostname, pid),
278            os: os_info,
279            pid: Some(pid),
280            telemetry: Some(WorkerTelemetryMeta {
281                language,
282                ..Default::default()
283            }),
284        }
285    }
286}
287
/// Items queued on the client's outbound channel for the connection loop.
// clippy flags the size gap between the variants. The lint is silenced rather
// than boxing `Message` — presumably to avoid an allocation per message.
#[allow(clippy::large_enum_variant)]
enum Outbound {
    /// A protocol message to send to the engine.
    Message(Message),
    /// Shutdown signal for the connection loop (queued by `shutdown`/`shutdown_async`).
    Shutdown,
}
293
/// One-shot sender that delivers the result (or error) of an in-flight
/// request. Stored in `IIIInner::pending`, keyed by `Uuid`
/// (presumably the invocation id — confirm).
type PendingInvocation = oneshot::Sender<Result<Value, IIIError>>;

// WebSocket transmitter type alias: the write half of a split
// `WebSocketStream` over a plain or TLS TCP stream.
type WsTx = futures_util::stream::SplitSink<
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
    WsMessage,
>;
301
302/// Inject trace context headers for outbound messages.
303fn inject_trace_headers() -> (Option<String>, Option<String>) {
304    use crate::telemetry::context;
305    (context::inject_traceparent(), context::inject_baggage())
306}
307
/// Connection state for the III WebSocket client.
///
/// A new client starts as `Disconnected`, and `shutdown`/`shutdown_async`
/// set it back to `Disconnected`. The other states are presumably set by the
/// connection loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IIIConnectionState {
    /// Not connected (initial state, and the state after shutdown).
    Disconnected,
    /// A first connection attempt is presumably in progress.
    Connecting,
    /// Connected to the engine.
    Connected,
    /// Presumably re-establishing a dropped connection.
    Reconnecting,
    /// Presumably gave up connecting.
    Failed,
}
317
/// Callback function type for functions-available events; it receives the list
/// of available functions.
///
/// A registered callback is removed when its [`FunctionsAvailableGuard`] is dropped.
pub type FunctionsAvailableCallback = Arc<dyn Fn(Vec<FunctionInfo>) + Send + Sync>;
320
/// Handle to a function registered through [`III`].
///
/// Clones share the same unregister action.
#[derive(Clone)]
pub struct FunctionRef {
    /// Id the function was registered under.
    pub id: String,
    unregister_fn: Arc<dyn Fn() + Send + Sync>,
}

impl FunctionRef {
    /// Unregister the function by running the unregister action captured
    /// when it was registered.
    pub fn unregister(&self) {
        let unregister: &(dyn Fn() + Send + Sync) = &*self.unregister_fn;
        unregister();
    }
}
332
/// Conversion of a handler value into the parts of a function registration.
///
/// An implementation may fill in fields of the outgoing
/// `RegisterFunctionMessage` (for example `invocation`, or the request/response
/// schemas). It returns the local handler to run, or `None` when there is no
/// local handler (as for [`HttpInvocationConfig`]).
pub trait IntoFunctionHandler {
    fn into_parts(self, message: &mut RegisterFunctionMessage) -> Option<RemoteFunctionHandler>;
}
336
/// Trait for types that can be passed to [`III::register_function`].
///
/// Implemented for:
/// - [`RegisterFunction`] — the builder API (recommended)
/// - `(RegisterFunctionMessage, H)` — the legacy tuple API
pub trait IntoFunctionRegistration {
    /// Split into the registration message and the local handler, if any.
    /// The handler is `None` for registrations without one, e.g. HTTP-invoked functions.
    fn into_registration(self) -> (RegisterFunctionMessage, Option<RemoteFunctionHandler>);
}
345
346impl IntoFunctionRegistration for RegisterFunction {
347    fn into_registration(self) -> (RegisterFunctionMessage, Option<RemoteFunctionHandler>) {
348        (self.message, Some(self.handler))
349    }
350}
351
352impl<H: IntoFunctionHandler> IntoFunctionRegistration for (RegisterFunctionMessage, H) {
353    fn into_registration(self) -> (RegisterFunctionMessage, Option<RemoteFunctionHandler>) {
354        let (mut message, handler) = self;
355        let handler = handler.into_parts(&mut message);
356        (message, handler)
357    }
358}
359
360impl IntoFunctionHandler for HttpInvocationConfig {
361    fn into_parts(self, message: &mut RegisterFunctionMessage) -> Option<RemoteFunctionHandler> {
362        message.invocation = Some(self);
363        None
364    }
365}
366
367impl<F, Fut> IntoFunctionHandler for F
368where
369    F: Fn(Value) -> Fut + Send + Sync + 'static,
370    Fut: std::future::Future<Output = Result<Value, IIIError>> + Send + 'static,
371{
372    fn into_parts(self, _message: &mut RegisterFunctionMessage) -> Option<RemoteFunctionHandler> {
373        Some(Arc::new(move |input: Value| Box::pin(self(input))))
374    }
375}
376
// =============================================================================
// iii_fn — sync function wrapper
// =============================================================================

/// Wrapper for registering sync functions as III handlers via [`iii_fn`].
///
/// Created by [`iii_fn`]. Stores a pre-erased handler so that a single
/// [`IntoFunctionHandler`] impl covers all supported arities.
pub struct IIIFn<F = ()> {
    /// Type-erased handler built from the wrapped function.
    handler: RemoteFunctionHandler,
    /// Request schema derived from the function's input type, if any.
    request_format: Option<Value>,
    /// Response schema derived from the function's output type, if any.
    response_format: Option<Value>,
    /// Records the wrapped function type; no `F` value is stored.
    _marker: std::marker::PhantomData<F>,
}
391
392fn json_schema_for<T: schemars::JsonSchema>() -> Option<Value> {
393    serde_json::to_value(
394        schemars::r#gen::SchemaSettings::draft07()
395            .into_generator()
396            .into_root_schema_for::<T>(),
397    )
398    .ok()
399}
400
/// Helper trait used internally to convert a sync function into a
/// [`RemoteFunctionHandler`].
///
/// `Marker` only keeps blanket impls over different closure signatures from
/// overlapping. It is normally left to type inference (see [`iii_fn`]).
#[doc(hidden)]
pub trait IntoSyncHandler<Marker>: Send + Sync + 'static {
    /// Type-erase the function into a JSON-in/JSON-out handler.
    fn into_handler(self) -> RemoteFunctionHandler;
    /// JSON Schema for the request payload, if one can be derived (default: `None`).
    fn request_format() -> Option<Value> {
        None
    }
    /// JSON Schema for the response payload, if one can be derived (default: `None`).
    fn response_format() -> Option<Value> {
        None
    }
}
413
414// 1-arg sync — deserializes the entire JSON input as T
415impl<F, T, R, E> IntoSyncHandler<(T, R, E)> for F
416where
417    F: Fn(T) -> Result<R, E> + Send + Sync + 'static,
418    T: serde::de::DeserializeOwned + schemars::JsonSchema + Send + 'static,
419    R: serde::Serialize + schemars::JsonSchema + Send + 'static,
420    E: std::fmt::Display + Send + 'static,
421{
422    fn into_handler(self) -> RemoteFunctionHandler {
423        Arc::new(move |input: Value| {
424            let output = serde_json::from_value::<T>(input)
425                .map_err(|e| IIIError::Handler(e.to_string()))
426                .and_then(|arg| (self)(arg).map_err(|e| IIIError::Handler(e.to_string())))
427                .and_then(|val| {
428                    serde_json::to_value(&val).map_err(|e| IIIError::Handler(e.to_string()))
429                });
430            Box::pin(async move { output })
431        })
432    }
433
434    fn request_format() -> Option<Value> {
435        json_schema_for::<T>()
436    }
437
438    fn response_format() -> Option<Value> {
439        json_schema_for::<R>()
440    }
441}
442
443/// Wraps a **sync** function into an III-compatible handler.
444///
445/// The function must take a single argument implementing
446/// [`serde::de::DeserializeOwned`] and return `Result<R, E>`
447/// where `R: Serialize` and `E: Display`.
448///
449/// The entire JSON input is deserialized as the argument type.
450/// Use a `#[derive(Deserialize)]` struct for named JSON keys.
451///
452/// For async functions, use [`iii_async_fn`] instead.
453pub fn iii_fn<F, M>(f: F) -> IIIFn<F>
454where
455    F: IntoSyncHandler<M>,
456{
457    IIIFn {
458        request_format: F::request_format(),
459        response_format: F::response_format(),
460        handler: f.into_handler(),
461        _marker: std::marker::PhantomData,
462    }
463}
464
465impl<F> IntoFunctionHandler for IIIFn<F> {
466    fn into_parts(self, message: &mut RegisterFunctionMessage) -> Option<RemoteFunctionHandler> {
467        if message.request_format.is_none() {
468            message.request_format = self.request_format;
469        }
470        if message.response_format.is_none() {
471            message.response_format = self.response_format;
472        }
473        Some(self.handler)
474    }
475}
476
// =============================================================================
// iii_async_fn — async function wrapper
// =============================================================================

/// Wrapper for registering async functions as III handlers via [`iii_async_fn`].
///
/// Created by [`iii_async_fn`]. Stores a pre-erased handler so that a single
/// [`IntoFunctionHandler`] impl covers all supported arities.
pub struct IIIAsyncFn<F = ()> {
    /// Type-erased handler built from the wrapped async function.
    handler: RemoteFunctionHandler,
    /// Request schema derived from the function's input type, if any.
    request_format: Option<Value>,
    /// Response schema derived from the function's output type, if any.
    response_format: Option<Value>,
    /// Records the wrapped function type; no `F` value is stored.
    _marker: std::marker::PhantomData<F>,
}
491
/// Helper trait used internally to convert an async function into a
/// [`RemoteFunctionHandler`].
///
/// `Marker` only keeps blanket impls over different closure signatures from
/// overlapping. It is normally left to type inference (see [`iii_async_fn`]).
#[doc(hidden)]
pub trait IntoAsyncHandler<Marker>: Send + Sync + 'static {
    /// Type-erase the async function into a JSON-in/JSON-out handler.
    fn into_handler(self) -> RemoteFunctionHandler;
    /// JSON Schema for the request payload, if one can be derived (default: `None`).
    fn request_format() -> Option<Value> {
        None
    }
    /// JSON Schema for the response payload, if one can be derived (default: `None`).
    fn response_format() -> Option<Value> {
        None
    }
}
504
505// 1-arg async — deserializes the entire JSON input as T
506impl<F, T, Fut, R, E> IntoAsyncHandler<(T, Fut, R, E)> for F
507where
508    F: Fn(T) -> Fut + Send + Sync + 'static,
509    T: serde::de::DeserializeOwned + schemars::JsonSchema + Send + 'static,
510    Fut: std::future::Future<Output = Result<R, E>> + Send + 'static,
511    R: serde::Serialize + schemars::JsonSchema + Send + 'static,
512    E: std::fmt::Display + Send + 'static,
513{
514    fn into_handler(self) -> RemoteFunctionHandler {
515        Arc::new(
516            move |input: Value| -> std::pin::Pin<
517                Box<dyn std::future::Future<Output = Result<Value, IIIError>> + Send>,
518            > {
519                match serde_json::from_value::<T>(input) {
520                    Ok(arg) => {
521                        let fut = (self)(arg);
522                        Box::pin(async move {
523                            fut.await
524                                .map_err(|e| IIIError::Handler(e.to_string()))
525                                .and_then(|val| {
526                                    serde_json::to_value(&val)
527                                        .map_err(|e| IIIError::Handler(e.to_string()))
528                                })
529                        })
530                    }
531                    Err(e) => Box::pin(async move { Err(IIIError::Handler(e.to_string())) }),
532                }
533            },
534        )
535    }
536
537    fn request_format() -> Option<Value> {
538        json_schema_for::<T>()
539    }
540
541    fn response_format() -> Option<Value> {
542        json_schema_for::<R>()
543    }
544}
545
546/// Wraps an **async** function into an III-compatible handler.
547///
548/// The function must take a single argument implementing
549/// [`serde::de::DeserializeOwned`] and return
550/// `impl Future<Output = Result<R, E>>` where `R: Serialize` and `E: Display`.
551pub fn iii_async_fn<F, M>(f: F) -> IIIAsyncFn<F>
552where
553    F: IntoAsyncHandler<M>,
554{
555    IIIAsyncFn {
556        request_format: F::request_format(),
557        response_format: F::response_format(),
558        handler: f.into_handler(),
559        _marker: std::marker::PhantomData,
560    }
561}
562
563impl<F> IntoFunctionHandler for IIIAsyncFn<F> {
564    fn into_parts(self, message: &mut RegisterFunctionMessage) -> Option<RemoteFunctionHandler> {
565        if message.request_format.is_none() {
566            message.request_format = self.request_format;
567        }
568        if message.response_format.is_none() {
569            message.response_format = self.response_format;
570        }
571        Some(self.handler)
572    }
573}
574
// =============================================================================
// RegisterFunction — one-step registration builder
// =============================================================================

/// One-step function registration combining ID, handler, and auto-generated schemas.
///
/// Use [`RegisterFunction::new`] for sync functions or [`RegisterFunction::new_async`]
/// for async functions, then register with [`III::register_function`].
pub struct RegisterFunction {
    /// Registration message sent to the engine (id, description, schemas, metadata).
    message: RegisterFunctionMessage,
    /// Type-erased local handler invoked for this function.
    handler: RemoteFunctionHandler,
}
587
588impl RegisterFunction {
589    /// Create a registration for a **sync** function.
590    pub fn new<F, M>(id: impl Into<String>, f: F) -> Self
591    where
592        F: IntoSyncHandler<M>,
593    {
594        Self {
595            message: RegisterFunctionMessage {
596                id: id.into(),
597                description: None,
598                request_format: F::request_format(),
599                response_format: F::response_format(),
600                metadata: None,
601                invocation: None,
602            },
603            handler: f.into_handler(),
604        }
605    }
606
607    /// Create a registration for an **async** function.
608    pub fn new_async<F, M>(id: impl Into<String>, f: F) -> Self
609    where
610        F: IntoAsyncHandler<M>,
611    {
612        Self {
613            message: RegisterFunctionMessage {
614                id: id.into(),
615                description: None,
616                request_format: F::request_format(),
617                response_format: F::response_format(),
618                metadata: None,
619                invocation: None,
620            },
621            handler: f.into_handler(),
622        }
623    }
624
625    /// Set the function description.
626    pub fn description(mut self, desc: impl Into<String>) -> Self {
627        self.message.description = Some(desc.into());
628        self
629    }
630
631    /// Set function metadata.
632    pub fn metadata(mut self, meta: Value) -> Self {
633        self.message.metadata = Some(meta);
634        self
635    }
636
637    /// Get the auto-generated request format.
638    pub fn request_format(&self) -> Option<&Value> {
639        self.message.request_format.as_ref()
640    }
641
642    /// Get the auto-generated response format.
643    pub fn response_format(&self) -> Option<&Value> {
644        self.message.response_format.as_ref()
645    }
646}
647
/// State shared by every clone of [`III`].
struct IIIInner {
    /// Engine WebSocket address.
    address: String,
    /// Producer side of the outbound queue (messages and the shutdown signal).
    outbound: mpsc::UnboundedSender<Outbound>,
    /// Consumer side of the outbound queue; taken once by `connect` and moved
    /// into the connection thread.
    receiver: Mutex<Option<mpsc::UnboundedReceiver<Outbound>>>,
    /// Set by `connect`, cleared by `shutdown`/`shutdown_async`.
    running: AtomicBool,
    /// Set by the first `connect` call so later calls return early.
    started: AtomicBool,
    /// Result senders for in-flight requests, keyed by `Uuid`
    /// (presumably the invocation id — confirm).
    pending: Mutex<HashMap<Uuid, PendingInvocation>>,
    /// Locally registered functions, keyed by function id.
    functions: Mutex<HashMap<String, RemoteFunctionData>>,
    /// Registered trigger types (presumably keyed by trigger type id).
    trigger_types: Mutex<HashMap<String, RemoteTriggerTypeData>>,
    /// Registered triggers (presumably keyed by trigger id).
    triggers: Mutex<HashMap<String, RegisterTriggerMessage>>,
    /// Registered services (presumably keyed by service id).
    services: Mutex<HashMap<String, RegisterServiceMessage>>,
    /// Worker metadata; replaced by `set_metadata`.
    worker_metadata: Mutex<Option<WorkerMetadata>>,
    /// Current connection state; starts as `Disconnected`.
    connection_state: Mutex<IIIConnectionState>,
    /// Join handle of the background connection thread, stored by `connect`
    /// and taken by `shutdown`.
    connection_thread: Mutex<Option<std::thread::JoinHandle<()>>>,
    /// Functions-available subscribers, keyed by callback id.
    functions_available_callbacks: Mutex<HashMap<usize, FunctionsAvailableCallback>>,
    /// Source of callback ids (presumably incremented per subscription).
    functions_available_callback_counter: AtomicUsize,
    /// NOTE(review): presumably the id of the internal function receiving
    /// functions-available events; confirm.
    functions_available_function_id: Mutex<Option<String>>,
    /// Trigger backing functions-available events; unregistered when the last
    /// `FunctionsAvailableGuard` is dropped.
    functions_available_trigger: Mutex<Option<Trigger>>,
    /// Custom HTTP headers for the WebSocket handshake (set via `set_headers`).
    headers: Mutex<Option<HashMap<String, String>>>,
    /// OpenTelemetry config set via `set_otel_config`; taken by `connect`.
    otel_config: Mutex<Option<OtelConfig>>,
}
669
/// WebSocket client for communication with the III Engine.
///
/// Create with [`register_worker`](crate::register_worker).
///
/// Cloning is cheap: every clone shares the same state through an `Arc`.
#[derive(Clone)]
pub struct III {
    inner: Arc<IIIInner>,
}
677
678/// Guard that unsubscribes from functions available events when dropped
679pub struct FunctionsAvailableGuard {
680    iii: III,
681    callback_id: usize,
682}
683
684impl Drop for FunctionsAvailableGuard {
685    fn drop(&mut self) {
686        let mut callbacks = self
687            .iii
688            .inner
689            .functions_available_callbacks
690            .lock_or_recover();
691        callbacks.remove(&self.callback_id);
692
693        if callbacks.is_empty() {
694            let mut trigger = self.iii.inner.functions_available_trigger.lock_or_recover();
695            if let Some(trigger) = trigger.take() {
696                trigger.unregister();
697            }
698        }
699    }
700}
701
702impl III {
703    /// Create a new III with default worker metadata (auto-detected runtime, os, hostname)
704    pub fn new(address: &str) -> Self {
705        Self::with_metadata(address, WorkerMetadata::default())
706    }
707
708    /// Create a new III with custom worker metadata
709    pub fn with_metadata(address: &str, metadata: WorkerMetadata) -> Self {
710        let (tx, rx) = mpsc::unbounded_channel();
711        let inner = IIIInner {
712            address: address.into(),
713            outbound: tx,
714            receiver: Mutex::new(Some(rx)),
715            running: AtomicBool::new(false),
716            started: AtomicBool::new(false),
717            pending: Mutex::new(HashMap::new()),
718            functions: Mutex::new(HashMap::new()),
719            trigger_types: Mutex::new(HashMap::new()),
720            triggers: Mutex::new(HashMap::new()),
721            services: Mutex::new(HashMap::new()),
722            worker_metadata: Mutex::new(Some(metadata)),
723            connection_state: Mutex::new(IIIConnectionState::Disconnected),
724            connection_thread: Mutex::new(None),
725            functions_available_callbacks: Mutex::new(HashMap::new()),
726            functions_available_callback_counter: AtomicUsize::new(0),
727            functions_available_function_id: Mutex::new(None),
728            functions_available_trigger: Mutex::new(None),
729            headers: Mutex::new(None),
730            otel_config: Mutex::new(None),
731        };
732        Self {
733            inner: Arc::new(inner),
734        }
735    }
736
737    /// Get the engine WebSocket address this client connects to.
738    pub fn address(&self) -> &str {
739        &self.inner.address
740    }
741
742    /// Set custom worker metadata (call before connect)
743    pub fn set_metadata(&self, metadata: WorkerMetadata) {
744        *self.inner.worker_metadata.lock_or_recover() = Some(metadata);
745    }
746
747    /// Set custom HTTP headers for the WebSocket handshake (call before connect).
748    pub fn set_headers(&self, headers: HashMap<String, String>) {
749        *self.inner.headers.lock_or_recover() = Some(headers);
750    }
751
752    /// Set OpenTelemetry configuration (call before connect)
753    pub fn set_otel_config(&self, config: OtelConfig) {
754        *self.inner.otel_config.lock_or_recover() = Some(config);
755    }
756
757    pub(crate) fn connect(&self) {
758        if self.inner.started.swap(true, Ordering::SeqCst) {
759            return;
760        }
761
762        let receiver = self.inner.receiver.lock_or_recover().take();
763        let Some(rx) = receiver else { return };
764
765        self.inner.running.store(true, Ordering::SeqCst);
766
767        let iii = self.clone();
768
769        let otel_config = {
770            let mut config = self
771                .inner
772                .otel_config
773                .lock_or_recover()
774                .take()
775                .unwrap_or_default();
776            if config.engine_ws_url.is_none() {
777                config.engine_ws_url = Some(self.inner.address.clone());
778            }
779            config
780        };
781
782        // Spawn a dedicated OS thread with its own tokio runtime so
783        // the connection loop is independent of the caller's runtime.
784        // In Rust, a spawned thread does not keep the process alive on its own;
785        // call shutdown() to signal the thread and join connection_thread so
786        // run_connection() can exit cleanly before main() returns.
787        let handle = std::thread::Builder::new()
788            .name("iii-connection".into())
789            .spawn(move || {
790                let rt = tokio::runtime::Builder::new_current_thread()
791                    .enable_all()
792                    .build()
793                    .expect("failed to create iii connection runtime");
794
795                rt.block_on(async move {
796                    let otel_active = telemetry::init_otel(otel_config).await;
797
798                    iii.run_connection(rx).await;
799
800                    if otel_active {
801                        telemetry::shutdown_otel().await;
802                    }
803                });
804            })
805            .expect("failed to spawn iii connection thread");
806
807        *self.inner.connection_thread.lock_or_recover() = Some(handle);
808    }
809
810    /// Shutdown the III client and wait for the connection thread to finish.
811    ///
812    /// This stops the connection loop, sends a shutdown signal, and joins
813    /// the background connection thread. Telemetry is flushed inside the
814    /// connection thread before it exits.
815    pub fn shutdown(&self) {
816        self.inner.running.store(false, Ordering::SeqCst);
817        let _ = self.inner.outbound.send(Outbound::Shutdown);
818        self.set_connection_state(IIIConnectionState::Disconnected);
819
820        if let Some(handle) = self.inner.connection_thread.lock_or_recover().take() {
821            let _ = handle.join();
822        }
823    }
824
825    /// Shutdown the III client.
826    ///
827    /// This stops the connection loop and sends a shutdown signal, but it
828    /// does not join `connection_thread`.
829    ///
830    /// Unlike [`shutdown`](Self::shutdown), this method does **not** block
831    /// to wait for `run_connection()` to finish, making it safe to call from
832    /// an async context without stalling the executor.
833    /// `telemetry::shutdown_otel()` still runs inside the connection thread
834    /// after `run_connection()` returns, so it may not complete unless
835    /// [`shutdown`](Self::shutdown) is used to join the thread.
836    pub async fn shutdown_async(&self) {
837        self.inner.running.store(false, Ordering::SeqCst);
838        let _ = self.inner.outbound.send(Outbound::Shutdown);
839        self.set_connection_state(IIIConnectionState::Disconnected);
840    }
841
842    fn register_function_inner(
843        &self,
844        message: RegisterFunctionMessage,
845        handler: Option<RemoteFunctionHandler>,
846    ) -> FunctionRef {
847        let id = message.id.clone();
848        if id.trim().is_empty() {
849            panic!("id is required");
850        }
851        let data = RemoteFunctionData {
852            message: message.clone(),
853            handler,
854        };
855        let mut funcs = self.inner.functions.lock_or_recover();
856        match funcs.entry(id.clone()) {
857            std::collections::hash_map::Entry::Occupied(_) => {
858                panic!("function id '{}' already registered", id);
859            }
860            std::collections::hash_map::Entry::Vacant(entry) => {
861                entry.insert(data);
862            }
863        }
864        drop(funcs);
865        let _ = self.send_message(message.to_message());
866
867        let iii = self.clone();
868        let unregister_id = id.clone();
869        let unregister_fn = Arc::new(move || {
870            let _ = iii.inner.functions.lock_or_recover().remove(&unregister_id);
871            let _ = iii.send_message(Message::UnregisterFunction {
872                id: unregister_id.clone(),
873            });
874        });
875
876        FunctionRef { id, unregister_fn }
877    }
878
879    /// Register a function with the engine.
880    ///
881    /// Pass a closure/async fn for local execution, or an [`HttpInvocationConfig`]
882    /// for HTTP-invoked functions (Lambda, Cloudflare Workers, etc.).
883    ///
884    /// # Arguments
885    /// * `message` - Function registration message with id and optional metadata.
886    /// * `handler` - Async handler or HTTP invocation config.
887    ///
888    /// # Panics
889    /// Panics if `id` is empty or already registered.
890    ///
891    /// # Examples
892    /// ```rust,no_run
893    /// use iii_sdk::{register_worker, InitOptions, RegisterFunction};
894    /// use serde::Deserialize;
895    /// use schemars::JsonSchema;
896    ///
897    /// #[derive(Deserialize, JsonSchema)]
898    /// struct Input { name: String }
899    /// fn greet(input: Input) -> Result<String, String> {
900    ///     Ok(format!("Hello, {}!", input.name))
901    /// }
902    ///
903    /// let iii = register_worker("ws://localhost:49134", InitOptions::default());
904    /// iii.register_function(RegisterFunction::new("greet", greet));
905    /// ```
906    ///
907    /// Also accepts a two-argument form via [`register_function_with`](III::register_function_with):
908    /// ```rust,no_run
909    /// # use iii_sdk::{register_worker, InitOptions, RegisterFunctionMessage};
910    /// # use serde_json::{json, Value};
911    /// # let iii = register_worker("ws://localhost:49134", InitOptions::default());
912    /// iii.register_function_with(
913    ///     RegisterFunctionMessage::with_id("echo".to_string()),
914    ///     |input: Value| async move { Ok(json!({"echo": input})) },
915    /// );
916    /// ```
917    pub fn register_function<R: IntoFunctionRegistration>(&self, registration: R) -> FunctionRef {
918        let (message, handler) = registration.into_registration();
919        self.register_function_inner(message, handler)
920    }
921
922    /// Register a function with a message and handler directly.
923    pub fn register_function_with<H: IntoFunctionHandler>(
924        &self,
925        mut message: RegisterFunctionMessage,
926        handler: H,
927    ) -> FunctionRef {
928        let handler = handler.into_parts(&mut message);
929        self.register_function_inner(message, handler)
930    }
931
932    /// Register a service with the engine.
933    ///
934    /// # Arguments
935    /// * `message` - Service registration message with id, name, and optional metadata.
936    pub fn register_service(&self, message: RegisterServiceMessage) {
937        self.inner
938            .services
939            .lock_or_recover()
940            .insert(message.id.clone(), message.clone());
941        let _ = self.send_message(message.to_message());
942    }
943
944    /// Register a custom trigger type with the engine.
945    ///
946    /// Returns a [`TriggerTypeRef`] handle that can register triggers and
947    /// functions with compile-time validated types.
948    ///
949    /// # Examples
950    /// ```rust,no_run
951    /// # use iii_sdk::{III, RegisterTriggerType};
952    /// # struct MyHandler;
953    /// # #[async_trait::async_trait]
954    /// # impl iii_sdk::TriggerHandler for MyHandler {
955    /// #     async fn register_trigger(&self, _: iii_sdk::TriggerConfig) -> Result<(), iii_sdk::IIIError> { Ok(()) }
956    /// #     async fn unregister_trigger(&self, _: iii_sdk::TriggerConfig) -> Result<(), iii_sdk::IIIError> { Ok(()) }
957    /// # }
958    /// # #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] struct MyConfig { url: String }
959    /// # #[derive(serde::Deserialize, schemars::JsonSchema)] struct MyRequest { data: String }
960    /// # let iii = III::new("ws://localhost:49134");
961    /// let my_trigger = iii.register_trigger_type(
962    ///     RegisterTriggerType::new("my-trigger", "My custom trigger", MyHandler)
963    ///         .trigger_request_format::<MyConfig>()
964    ///         .call_request_format::<MyRequest>(),
965    /// );
966    ///
967    /// // Compile-time safe: config must be MyConfig, function input must be MyRequest
968    /// my_trigger.register_function("my::handler", |req: MyRequest| -> Result<serde_json::Value, String> {
969    ///     Ok(serde_json::json!({ "data": req.data }))
970    /// });
971    /// my_trigger.register_trigger("my::handler", MyConfig { url: "/hook".into() });
972    /// ```
973    pub fn register_trigger_type<H, C, R>(
974        &self,
975        registration: RegisterTriggerType<H, C, R>,
976    ) -> TriggerTypeRef<C, R>
977    where
978        H: TriggerHandler + 'static,
979    {
980        let message = RegisterTriggerTypeMessage {
981            id: registration.id,
982            description: registration.description,
983            trigger_request_format: registration.trigger_request_format,
984            call_request_format: registration.call_request_format,
985        };
986
987        let trigger_type_id = message.id.clone();
988
989        self.inner.trigger_types.lock_or_recover().insert(
990            message.id.clone(),
991            RemoteTriggerTypeData {
992                message: message.clone(),
993                handler: Arc::new(registration.handler),
994            },
995        );
996
997        let _ = self.send_message(message.to_message());
998
999        TriggerTypeRef {
1000            iii: self.clone(),
1001            trigger_type_id,
1002            _phantom: std::marker::PhantomData,
1003        }
1004    }
1005
1006    /// Unregister a previously registered trigger type.
1007    pub fn unregister_trigger_type(&self, id: impl Into<String>) {
1008        let id = id.into();
1009        self.inner.trigger_types.lock_or_recover().remove(&id);
1010        let msg = UnregisterTriggerTypeMessage { id };
1011        let _ = self.send_message(msg.to_message());
1012    }
1013
1014    /// Bind a trigger configuration to a registered function.
1015    ///
1016    /// # Arguments
1017    /// * `input` - Trigger registration input with trigger_type, function_id, and config.
1018    ///
1019    /// # Examples
1020    /// ```rust
1021    /// # use iii_sdk::{III, RegisterTriggerInput};
1022    /// # use serde_json::json;
1023    /// # let iii = III::new("ws://localhost:49134");
1024    /// let trigger = iii.register_trigger(RegisterTriggerInput {
1025    ///     trigger_type: "http".to_string(),
1026    ///     function_id: "greet".to_string(),
1027    ///     config: json!({ "api_path": "/greet", "http_method": "GET" }),
1028    ///     metadata: None,
1029    /// })?;
1030    /// // Later...
1031    /// trigger.unregister();
1032    /// # Ok::<(), iii_sdk::IIIError>(())
1033    /// ```
1034    pub fn register_trigger(&self, input: RegisterTriggerInput) -> Result<Trigger, IIIError> {
1035        let id = Uuid::new_v4().to_string();
1036        let message = RegisterTriggerMessage {
1037            id: id.clone(),
1038            trigger_type: input.trigger_type,
1039            function_id: input.function_id,
1040            config: input.config,
1041            metadata: input.metadata,
1042        };
1043
1044        self.inner
1045            .triggers
1046            .lock_or_recover()
1047            .insert(message.id.clone(), message.clone());
1048        let _ = self.send_message(message.to_message());
1049
1050        let iii = self.clone();
1051        let trigger_type = message.trigger_type.clone();
1052        let unregister_id = message.id.clone();
1053        let unregister_fn = Arc::new(move || {
1054            let _ = iii.inner.triggers.lock_or_recover().remove(&unregister_id);
1055            let msg = UnregisterTriggerMessage {
1056                id: unregister_id.clone(),
1057                trigger_type: trigger_type.clone(),
1058            };
1059            let _ = iii.send_message(msg.to_message());
1060        });
1061
1062        Ok(Trigger::new(unregister_fn))
1063    }
1064
1065    /// Invoke a remote function.
1066    ///
1067    /// The routing behavior depends on the `action` field of the request:
1068    /// - No action: synchronous -- waits for the function to return.
1069    /// - [`TriggerAction::Enqueue`] - async via named queue.
1070    /// - [`TriggerAction::Void`] — fire-and-forget.
1071    ///
1072    /// # Examples
1073    /// ```rust
1074    /// # use iii_sdk::{III, TriggerRequest, TriggerAction};
1075    /// # use serde_json::json;
1076    /// # async fn example(iii: &III) -> Result<(), iii_sdk::IIIError> {
1077    /// // Synchronous
1078    /// let result = iii.trigger(TriggerRequest {
1079    ///     function_id: "greet".to_string(),
1080    ///     payload: json!({"name": "World"}),
1081    ///     action: None,
1082    ///     timeout_ms: None,
1083    /// }).await?;
1084    ///
1085    /// // Fire-and-forget
1086    /// iii.trigger(TriggerRequest {
1087    ///     function_id: "notify".to_string(),
1088    ///     payload: json!({}),
1089    ///     action: Some(TriggerAction::Void),
1090    ///     timeout_ms: None,
1091    /// }).await?;
1092    ///
1093    /// // Enqueue
1094    /// let receipt = iii.trigger(TriggerRequest {
1095    ///     function_id: "iii::durable::publish".to_string(),
1096    ///     payload: json!({"topic": "test"}),
1097    ///     action: Some(TriggerAction::Enqueue { queue: "test".to_string() }),
1098    ///     timeout_ms: None,
1099    /// }).await?;
1100    ///
1101    /// # Ok(())
1102    /// # }
1103    /// ```
1104    pub async fn trigger(
1105        &self,
1106        request: impl Into<crate::protocol::TriggerRequest>,
1107    ) -> Result<Value, IIIError> {
1108        let req = request.into();
1109        let (tp, bg) = inject_trace_headers();
1110
1111        // Void is fire-and-forget — no invocation_id, no response
1112        if matches!(req.action, Some(TriggerAction::Void)) {
1113            self.send_message(Message::InvokeFunction {
1114                invocation_id: None,
1115                function_id: req.function_id,
1116                data: req.payload,
1117                traceparent: tp,
1118                baggage: bg,
1119                action: req.action,
1120            })?;
1121            return Ok(Value::Null);
1122        }
1123
1124        // Enqueue and default: use invocation_id to receive acknowledgement/result
1125        let timeout = Duration::from_millis(req.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS));
1126        let invocation_id = Uuid::new_v4();
1127        let (tx, rx) = oneshot::channel();
1128
1129        self.inner
1130            .pending
1131            .lock_or_recover()
1132            .insert(invocation_id, tx);
1133
1134        self.send_message(Message::InvokeFunction {
1135            invocation_id: Some(invocation_id),
1136            function_id: req.function_id,
1137            data: req.payload,
1138            traceparent: tp,
1139            baggage: bg,
1140            action: req.action,
1141        })?;
1142
1143        match tokio::time::timeout(timeout, rx).await {
1144            Ok(Ok(result)) => result,
1145            Ok(Err(_)) => Err(IIIError::NotConnected),
1146            Err(_) => {
1147                self.inner.pending.lock_or_recover().remove(&invocation_id);
1148                Err(IIIError::Timeout)
1149            }
1150        }
1151    }
1152
1153    /// Get the current connection state.
1154    pub fn get_connection_state(&self) -> IIIConnectionState {
1155        *self.inner.connection_state.lock_or_recover()
1156    }
1157
1158    fn set_connection_state(&self, state: IIIConnectionState) {
1159        let mut current = self.inner.connection_state.lock_or_recover();
1160        if *current == state {
1161            return;
1162        }
1163        *current = state;
1164    }
1165
1166    /// List all registered functions from the engine
1167    pub async fn list_functions(&self) -> Result<Vec<FunctionInfo>, IIIError> {
1168        let result = self
1169            .trigger(TriggerRequest {
1170                function_id: "engine::functions::list".to_string(),
1171                payload: serde_json::json!({}),
1172                action: None,
1173                timeout_ms: None,
1174            })
1175            .await?;
1176
1177        let functions = result
1178            .get("functions")
1179            .and_then(|v| serde_json::from_value::<Vec<FunctionInfo>>(v.clone()).ok())
1180            .unwrap_or_default();
1181
1182        Ok(functions)
1183    }
1184
1185    /// Subscribe to function availability events
1186    /// Returns a guard that will unsubscribe when dropped
1187    pub fn on_functions_available<F>(&self, callback: F) -> FunctionsAvailableGuard
1188    where
1189        F: Fn(Vec<FunctionInfo>) + Send + Sync + 'static,
1190    {
1191        let callback = Arc::new(callback);
1192        let callback_id = self
1193            .inner
1194            .functions_available_callback_counter
1195            .fetch_add(1, Ordering::Relaxed);
1196
1197        self.inner
1198            .functions_available_callbacks
1199            .lock_or_recover()
1200            .insert(callback_id, callback);
1201
1202        // Set up trigger if not already done
1203        let mut trigger_guard = self.inner.functions_available_trigger.lock_or_recover();
1204        if trigger_guard.is_none() {
1205            // Get or create function path (reuse existing if trigger registration previously failed)
1206            let function_id = {
1207                let mut path_guard = self.inner.functions_available_function_id.lock_or_recover();
1208                if path_guard.is_none() {
1209                    let path = format!("iii.on_functions_available.{}", Uuid::new_v4());
1210                    *path_guard = Some(path.clone());
1211                    path
1212                } else {
1213                    path_guard.clone().unwrap()
1214                }
1215            };
1216
1217            // Register handler function only if it doesn't already exist
1218            let function_exists = self
1219                .inner
1220                .functions
1221                .lock_or_recover()
1222                .contains_key(&function_id);
1223            if !function_exists {
1224                let iii = self.clone();
1225                self.register_function_with(
1226                    RegisterFunctionMessage {
1227                        id: function_id.clone(),
1228                        description: None,
1229                        request_format: None,
1230                        response_format: None,
1231                        metadata: None,
1232                        invocation: None,
1233                    },
1234                    move |input: Value| {
1235                        let iii = iii.clone();
1236                        async move {
1237                            let functions = input
1238                                .get("functions")
1239                                .and_then(|v| {
1240                                    serde_json::from_value::<Vec<FunctionInfo>>(v.clone()).ok()
1241                                })
1242                                .unwrap_or_default();
1243
1244                            let callbacks =
1245                                iii.inner.functions_available_callbacks.lock_or_recover();
1246                            for cb in callbacks.values() {
1247                                cb(functions.clone());
1248                            }
1249                            Ok(Value::Null)
1250                        }
1251                    },
1252                );
1253            }
1254
1255            match self.register_trigger(RegisterTriggerInput {
1256                trigger_type: "engine::functions-available".to_string(),
1257                function_id,
1258                config: serde_json::json!({}),
1259                metadata: None,
1260            }) {
1261                Ok(trigger) => {
1262                    *trigger_guard = Some(trigger);
1263                }
1264                Err(err) => {
1265                    tracing::warn!(error = %err, "Failed to register functions_available trigger");
1266                }
1267            }
1268        }
1269
1270        FunctionsAvailableGuard {
1271            iii: self.clone(),
1272            callback_id,
1273        }
1274    }
1275
1276    /// List all connected workers from the engine
1277    pub async fn list_workers(&self) -> Result<Vec<WorkerInfo>, IIIError> {
1278        let result = self
1279            .trigger(TriggerRequest {
1280                function_id: "engine::workers::list".to_string(),
1281                payload: serde_json::json!({}),
1282                action: None,
1283                timeout_ms: None,
1284            })
1285            .await?;
1286
1287        let workers = result
1288            .get("workers")
1289            .and_then(|v| serde_json::from_value::<Vec<WorkerInfo>>(v.clone()).ok())
1290            .unwrap_or_default();
1291
1292        Ok(workers)
1293    }
1294
1295    /// List all registered triggers from the engine
1296    pub async fn list_triggers(
1297        &self,
1298        include_internal: bool,
1299    ) -> Result<Vec<TriggerInfo>, IIIError> {
1300        let result = self
1301            .trigger(TriggerRequest {
1302                function_id: "engine::triggers::list".to_string(),
1303                payload: serde_json::json!({ "include_internal": include_internal }),
1304                action: None,
1305                timeout_ms: None,
1306            })
1307            .await?;
1308
1309        let triggers = result
1310            .get("triggers")
1311            .and_then(|v| serde_json::from_value::<Vec<TriggerInfo>>(v.clone()).ok())
1312            .unwrap_or_default();
1313
1314        Ok(triggers)
1315    }
1316
1317    /// List all registered trigger types from the engine with their
1318    /// `trigger_request_format` and `call_request_format` schemas.
1319    pub async fn list_trigger_types(
1320        &self,
1321        include_internal: bool,
1322    ) -> Result<Vec<TriggerTypeInfo>, IIIError> {
1323        let result = self
1324            .trigger(TriggerRequest {
1325                function_id: "engine::trigger-types::list".to_string(),
1326                payload: serde_json::json!({ "include_internal": include_internal }),
1327                action: None,
1328                timeout_ms: None,
1329            })
1330            .await?;
1331
1332        let trigger_types = result
1333            .get("trigger_types")
1334            .and_then(|v| serde_json::from_value::<Vec<TriggerTypeInfo>>(v.clone()).ok())
1335            .unwrap_or_default();
1336
1337        Ok(trigger_types)
1338    }
1339
1340    /// Create a streaming channel pair for worker-to-worker data transfer.
1341    ///
1342    /// Returns a `Channel` with writer, reader, and their serializable refs
1343    /// that can be passed as fields in invocation data to other functions.
1344    pub async fn create_channel(&self, buffer_size: Option<usize>) -> Result<Channel, IIIError> {
1345        let result = self
1346            .trigger(TriggerRequest {
1347                function_id: "engine::channels::create".to_string(),
1348                payload: serde_json::json!({ "buffer_size": buffer_size }),
1349                action: None,
1350                timeout_ms: None,
1351            })
1352            .await?;
1353
1354        let writer_ref: StreamChannelRef = serde_json::from_value(
1355            result
1356                .get("writer")
1357                .cloned()
1358                .ok_or_else(|| IIIError::Serde("missing 'writer' in channel response".into()))?,
1359        )
1360        .map_err(|e| IIIError::Serde(e.to_string()))?;
1361
1362        let reader_ref: StreamChannelRef = serde_json::from_value(
1363            result
1364                .get("reader")
1365                .cloned()
1366                .ok_or_else(|| IIIError::Serde("missing 'reader' in channel response".into()))?,
1367        )
1368        .map_err(|e| IIIError::Serde(e.to_string()))?;
1369
1370        Ok(Channel {
1371            writer: ChannelWriter::new(&self.inner.address, &writer_ref),
1372            reader: ChannelReader::new(&self.inner.address, &reader_ref),
1373            writer_ref,
1374            reader_ref,
1375        })
1376    }
1377
1378    /// Register this worker's metadata with the engine (called automatically on connect)
1379    fn register_worker_metadata(&self) {
1380        if let Some(mut metadata) = self.inner.worker_metadata.lock_or_recover().clone() {
1381            let fw = metadata
1382                .telemetry
1383                .as_ref()
1384                .and_then(|t| t.framework.as_deref())
1385                .unwrap_or("");
1386            if fw.is_empty() {
1387                let telem = metadata.telemetry.get_or_insert_with(Default::default);
1388                telem.framework = Some("iii-rust".to_string());
1389            }
1390            if let Ok(value) = serde_json::to_value(metadata) {
1391                let _ = self.send_message(Message::InvokeFunction {
1392                    invocation_id: None,
1393                    function_id: "engine::workers::register".to_string(),
1394                    data: value,
1395                    traceparent: None,
1396                    baggage: None,
1397                    action: Some(TriggerAction::Void),
1398                });
1399            }
1400        }
1401    }
1402
1403    fn send_message(&self, message: Message) -> Result<(), IIIError> {
1404        if !self.inner.running.load(Ordering::SeqCst) {
1405            return Ok(());
1406        }
1407
1408        self.inner
1409            .outbound
1410            .send(Outbound::Message(message))
1411            .map_err(|_| IIIError::NotConnected)
1412    }
1413
1414    async fn run_connection(&self, mut rx: mpsc::UnboundedReceiver<Outbound>) {
1415        let mut queue: Vec<Message> = Vec::new();
1416        let mut has_connected_before = false;
1417
1418        while self.inner.running.load(Ordering::SeqCst) {
1419            self.set_connection_state(if has_connected_before {
1420                IIIConnectionState::Reconnecting
1421            } else {
1422                IIIConnectionState::Connecting
1423            });
1424
1425            let custom_headers = self.inner.headers.lock_or_recover().clone();
1426
1427            let connect_result = if let Some(ref h) = custom_headers {
1428                use tokio_tungstenite::tungstenite::client::IntoClientRequest;
1429                use tokio_tungstenite::tungstenite::http;
1430                let mut request = self
1431                    .inner
1432                    .address
1433                    .as_str()
1434                    .into_client_request()
1435                    .expect("valid ws request");
1436                for (k, v) in h {
1437                    if let (Ok(name), Ok(val)) = (
1438                        http::header::HeaderName::from_bytes(k.as_bytes()),
1439                        http::header::HeaderValue::from_str(v),
1440                    ) {
1441                        request.headers_mut().insert(name, val);
1442                    }
1443                }
1444                connect_async(request).await
1445            } else {
1446                connect_async(&self.inner.address).await
1447            };
1448
1449            match connect_result {
1450                Ok((stream, _)) => {
1451                    tracing::info!(address = %self.inner.address, "iii connected");
1452                    has_connected_before = true;
1453                    self.set_connection_state(IIIConnectionState::Connected);
1454                    let (mut ws_tx, mut ws_rx) = stream.split();
1455
1456                    queue.extend(self.collect_registrations());
1457                    Self::dedupe_registrations(&mut queue);
1458                    if let Err(err) = self.flush_queue(&mut ws_tx, &mut queue).await {
1459                        tracing::warn!(error = %err, "failed to flush queue");
1460                        sleep(Duration::from_secs(2)).await;
1461                        continue;
1462                    }
1463
1464                    // Auto-register worker metadata on connect (like Node SDK)
1465                    self.register_worker_metadata();
1466
1467                    let mut should_reconnect = false;
1468
1469                    while self.inner.running.load(Ordering::SeqCst) && !should_reconnect {
1470                        tokio::select! {
1471                            outgoing = rx.recv() => {
1472                                match outgoing {
1473                                    Some(Outbound::Message(message)) => {
1474                                        if let Err(err) = self.send_ws(&mut ws_tx, &message).await {
1475                                            tracing::warn!(error = %err, "send failed; reconnecting");
1476                                            queue.push(message);
1477                                            should_reconnect = true;
1478                                        }
1479                                    }
1480                                    Some(Outbound::Shutdown) => {
1481                                        self.inner.running.store(false, Ordering::SeqCst);
1482                                        return;
1483                                    }
1484                                    None => {
1485                                        self.inner.running.store(false, Ordering::SeqCst);
1486                                        return;
1487                                    }
1488                                }
1489                            }
1490                            incoming = ws_rx.next() => {
1491                                match incoming {
1492                                    Some(Ok(frame)) => {
1493                                        if let Err(err) = self.handle_frame(frame) {
1494                                            tracing::warn!(error = %err, "failed to handle frame");
1495                                        }
1496                                    }
1497                                    Some(Err(err)) => {
1498                                        tracing::warn!(error = %err, "websocket receive error");
1499                                        should_reconnect = true;
1500                                    }
1501                                    None => {
1502                                        should_reconnect = true;
1503                                    }
1504                                }
1505                            }
1506                        }
1507                    }
1508                }
1509                Err(err) => {
1510                    tracing::warn!(error = %err, "failed to connect; retrying");
1511                }
1512            }
1513
1514            if self.inner.running.load(Ordering::SeqCst) {
1515                sleep(Duration::from_secs(2)).await;
1516            }
1517        }
1518    }
1519
1520    fn collect_registrations(&self) -> Vec<Message> {
1521        let mut messages = Vec::new();
1522
1523        for trigger_type in self.inner.trigger_types.lock_or_recover().values() {
1524            messages.push(trigger_type.message.to_message());
1525        }
1526
1527        for service in self.inner.services.lock_or_recover().values() {
1528            messages.push(service.to_message());
1529        }
1530
1531        for function in self.inner.functions.lock_or_recover().values() {
1532            messages.push(function.message.to_message());
1533        }
1534
1535        for trigger in self.inner.triggers.lock_or_recover().values() {
1536            messages.push(trigger.to_message());
1537        }
1538
1539        messages
1540    }
1541
1542    fn dedupe_registrations(queue: &mut Vec<Message>) {
1543        let mut seen = HashSet::new();
1544        let mut deduped_rev = Vec::with_capacity(queue.len());
1545
1546        for message in queue.iter().rev() {
1547            let key = match message {
1548                Message::RegisterTriggerType { id, .. } => format!("trigger_type:{id}"),
1549                Message::RegisterTrigger { id, .. } => format!("trigger:{id}"),
1550                Message::RegisterFunction { id, .. } => {
1551                    format!("function:{id}")
1552                }
1553                Message::RegisterService { id, .. } => format!("service:{id}"),
1554                _ => {
1555                    deduped_rev.push(message.clone());
1556                    continue;
1557                }
1558            };
1559
1560            if seen.insert(key) {
1561                deduped_rev.push(message.clone());
1562            }
1563        }
1564
1565        deduped_rev.reverse();
1566        *queue = deduped_rev;
1567    }
1568
1569    async fn flush_queue(
1570        &self,
1571        ws_tx: &mut WsTx,
1572        queue: &mut Vec<Message>,
1573    ) -> Result<(), IIIError> {
1574        let mut drained = Vec::new();
1575        std::mem::swap(queue, &mut drained);
1576
1577        let mut iter = drained.into_iter();
1578        while let Some(message) = iter.next() {
1579            if let Err(err) = self.send_ws(ws_tx, &message).await {
1580                queue.push(message);
1581                queue.extend(iter);
1582                return Err(err);
1583            }
1584        }
1585
1586        Ok(())
1587    }
1588
1589    async fn send_ws(&self, ws_tx: &mut WsTx, message: &Message) -> Result<(), IIIError> {
1590        let payload = serde_json::to_string(message)?;
1591        ws_tx.send(WsMessage::Text(payload.into())).await?;
1592        Ok(())
1593    }
1594
1595    fn handle_frame(&self, frame: WsMessage) -> Result<(), IIIError> {
1596        match frame {
1597            WsMessage::Text(text) => self.handle_message(&text),
1598            WsMessage::Binary(bytes) => {
1599                let text = String::from_utf8_lossy(&bytes).to_string();
1600                self.handle_message(&text)
1601            }
1602            _ => Ok(()),
1603        }
1604    }
1605
1606    fn handle_message(&self, payload: &str) -> Result<(), IIIError> {
1607        let message: Message = serde_json::from_str(payload)?;
1608
1609        match message {
1610            Message::InvocationResult {
1611                invocation_id,
1612                result,
1613                error,
1614                ..
1615            } => {
1616                self.handle_invocation_result(invocation_id, result, error);
1617            }
1618            Message::InvokeFunction {
1619                invocation_id,
1620                function_id,
1621                data,
1622                traceparent,
1623                baggage,
1624                action: _,
1625            } => {
1626                self.handle_invoke_function(invocation_id, function_id, data, traceparent, baggage);
1627            }
1628            Message::RegisterTrigger {
1629                id,
1630                trigger_type,
1631                function_id,
1632                config,
1633                metadata,
1634            } => {
1635                self.handle_register_trigger(id, trigger_type, function_id, config, metadata);
1636            }
1637            Message::Ping => {
1638                let _ = self.send_message(Message::Pong);
1639            }
1640            Message::WorkerRegistered { worker_id } => {
1641                tracing::debug!(worker_id = %worker_id, "Worker registered");
1642            }
1643            _ => {}
1644        }
1645
1646        Ok(())
1647    }
1648
1649    fn handle_invocation_result(
1650        &self,
1651        invocation_id: Uuid,
1652        result: Option<Value>,
1653        error: Option<ErrorBody>,
1654    ) {
1655        let sender = self.inner.pending.lock_or_recover().remove(&invocation_id);
1656        if let Some(sender) = sender {
1657            let result = match error {
1658                Some(error) => Err(IIIError::Remote {
1659                    code: error.code,
1660                    message: error.message,
1661                    stacktrace: error.stacktrace,
1662                }),
1663                None => Ok(result.unwrap_or(Value::Null)),
1664            };
1665            let _ = sender.send(result);
1666        }
1667    }
1668
1669    fn handle_invoke_function(
1670        &self,
1671        invocation_id: Option<Uuid>,
1672        function_id: String,
1673        data: Value,
1674        traceparent: Option<String>,
1675        baggage: Option<String>,
1676    ) {
1677        tracing::debug!(function_id = %function_id, traceparent = ?traceparent, baggage = ?baggage, "Invoking function");
1678
1679        let func_data = self
1680            .inner
1681            .functions
1682            .lock_or_recover()
1683            .get(&function_id)
1684            .cloned();
1685        let handler = func_data.as_ref().and_then(|d| d.handler.clone());
1686
1687        let Some(handler) = handler else {
1688            let (code, message) = match &func_data {
1689                Some(_) => (
1690                    "function_not_invokable".to_string(),
1691                    "Function is HTTP-invoked and cannot be invoked locally".to_string(),
1692                ),
1693                None => (
1694                    "function_not_found".to_string(),
1695                    "Function not found".to_string(),
1696                ),
1697            };
1698            tracing::warn!(function_id = %function_id, "Invocation: {}", message);
1699
1700            if let Some(invocation_id) = invocation_id {
1701                let (resp_tp, resp_bg) = inject_trace_headers();
1702
1703                let error = ErrorBody {
1704                    code,
1705                    message,
1706                    stacktrace: None,
1707                };
1708                let result = self.send_message(Message::InvocationResult {
1709                    invocation_id,
1710                    function_id,
1711                    result: None,
1712                    error: Some(error),
1713                    traceparent: resp_tp,
1714                    baggage: resp_bg,
1715                });
1716
1717                if let Err(err) = result {
1718                    tracing::warn!(error = %err, "error sending invocation result");
1719                }
1720            }
1721            return;
1722        };
1723
1724        let iii = self.clone();
1725
1726        tokio::spawn(async move {
1727            // Extract incoming trace context and create a span for this invocation.
1728            // This ensures the handler and any outbound calls it makes (e.g.
1729            // invoke_function_with_timeout) are linked as children of the caller's trace.
1730            // We use FutureExt::with_context() instead of cx.attach() because
1731            // ContextGuard is !Send and can't be held across .await in tokio::spawn.
1732            let otel_cx = {
1733                use crate::telemetry::context::extract_context;
1734                use opentelemetry::trace::{SpanKind, TraceContextExt, Tracer};
1735
1736                let parent_cx = extract_context(traceparent.as_deref(), baggage.as_deref());
1737                let tracer = opentelemetry::global::tracer("iii-rust-sdk");
1738                let span = tracer
1739                    .span_builder(format!("call {}", function_id))
1740                    .with_kind(SpanKind::Server)
1741                    .start_with_context(&tracer, &parent_cx);
1742                parent_cx.with_span(span)
1743            };
1744
1745            let result = {
1746                use opentelemetry::trace::FutureExt as OtelFutureExt;
1747                handler(data).with_context(otel_cx.clone()).await
1748            };
1749
1750            // Record span status based on result
1751            let mut error_stacktrace: Option<String> = None;
1752            {
1753                use opentelemetry::KeyValue;
1754                use opentelemetry::trace::{Status, TraceContextExt};
1755                let span = otel_cx.span();
1756                match &result {
1757                    Ok(_) => span.set_status(Status::Ok),
1758                    Err(err) => {
1759                        let (exc_type, exc_message, stacktrace) = match err {
1760                            IIIError::Remote {
1761                                code,
1762                                message,
1763                                stacktrace,
1764                            } => (
1765                                code.clone(),
1766                                message.clone(),
1767                                stacktrace.clone().unwrap_or_else(|| {
1768                                    std::backtrace::Backtrace::force_capture().to_string()
1769                                }),
1770                            ),
1771                            other => (
1772                                "InvocationError".to_string(),
1773                                other.to_string(),
1774                                std::backtrace::Backtrace::force_capture().to_string(),
1775                            ),
1776                        };
1777                        span.set_status(Status::error(exc_message.clone()));
1778                        span.add_event(
1779                            "exception",
1780                            vec![
1781                                KeyValue::new("exception.type", exc_type),
1782                                KeyValue::new("exception.message", exc_message),
1783                                KeyValue::new("exception.stacktrace", stacktrace.clone()),
1784                            ],
1785                        );
1786                        error_stacktrace = Some(stacktrace);
1787                    }
1788                }
1789            }
1790
1791            if let Some(invocation_id) = invocation_id {
1792                // Inject trace context from our span into the response.
1793                // We briefly attach the otel context (no .await crossing)
1794                // so inject_traceparent/inject_baggage can read it.
1795                let (resp_tp, resp_bg) = {
1796                    let _guard = otel_cx.attach();
1797                    inject_trace_headers()
1798                };
1799
1800                let message = match result {
1801                    Ok(value) => Message::InvocationResult {
1802                        invocation_id,
1803                        function_id,
1804                        result: Some(value),
1805                        error: None,
1806                        traceparent: resp_tp,
1807                        baggage: resp_bg,
1808                    },
1809                    Err(err) => {
1810                        let error_body = match err {
1811                            IIIError::Remote {
1812                                code,
1813                                message,
1814                                stacktrace,
1815                            } => ErrorBody {
1816                                code,
1817                                message,
1818                                stacktrace: stacktrace.or(error_stacktrace).or_else(|| {
1819                                    Some(std::backtrace::Backtrace::force_capture().to_string())
1820                                }),
1821                            },
1822                            other => ErrorBody {
1823                                code: "invocation_failed".to_string(),
1824                                message: other.to_string(),
1825                                stacktrace: error_stacktrace.or_else(|| {
1826                                    Some(std::backtrace::Backtrace::force_capture().to_string())
1827                                }),
1828                            },
1829                        };
1830                        Message::InvocationResult {
1831                            invocation_id,
1832                            function_id,
1833                            result: None,
1834                            error: Some(error_body),
1835                            traceparent: resp_tp,
1836                            baggage: resp_bg,
1837                        }
1838                    }
1839                };
1840
1841                let _ = iii.send_message(message);
1842            } else if let Err(err) = result {
1843                tracing::warn!(error = %err, "error handling async invocation");
1844            }
1845        });
1846    }
1847
1848    fn handle_register_trigger(
1849        &self,
1850        id: String,
1851        trigger_type: String,
1852        function_id: String,
1853        config: Value,
1854        metadata: Option<Value>,
1855    ) {
1856        let handler = self
1857            .inner
1858            .trigger_types
1859            .lock_or_recover()
1860            .get(&trigger_type)
1861            .map(|data| data.handler.clone());
1862
1863        let iii = self.clone();
1864
1865        tokio::spawn(async move {
1866            let message = if let Some(handler) = handler {
1867                let config = TriggerConfig {
1868                    id: id.clone(),
1869                    function_id: function_id.clone(),
1870                    config,
1871                    metadata,
1872                };
1873
1874                match handler.register_trigger(config).await {
1875                    Ok(()) => Message::TriggerRegistrationResult {
1876                        id,
1877                        trigger_type,
1878                        function_id,
1879                        error: None,
1880                    },
1881                    Err(err) => Message::TriggerRegistrationResult {
1882                        id,
1883                        trigger_type,
1884                        function_id,
1885                        error: Some(ErrorBody {
1886                            code: "trigger_registration_failed".to_string(),
1887                            message: err.to_string(),
1888                            stacktrace: None,
1889                        }),
1890                    },
1891                }
1892            } else {
1893                Message::TriggerRegistrationResult {
1894                    id,
1895                    trigger_type,
1896                    function_id,
1897                    error: Some(ErrorBody {
1898                        code: "trigger_type_not_found".to_string(),
1899                        message: "Trigger type not found".to_string(),
1900                        stacktrace: None,
1901                    }),
1902                }
1903            };
1904
1905            let _ = iii.send_message(message);
1906        });
1907    }
1908}
1909
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use serde_json::json;

    use super::*;
    use crate::{
        InitOptions,
        protocol::{HttpInvocationConfig, HttpMethod, RegisterTriggerInput},
        register_worker,
    };

    /// Placeholder engine address; these tests only inspect local worker state and do not
    /// rely on any reply arriving from it.
    const ENGINE_URL: &str = "ws://localhost:1234";

    /// Builds a function registration message that sets only the id.
    fn bare_function_message(id: &str) -> RegisterFunctionMessage {
        RegisterFunctionMessage {
            id: id.to_string(),
            description: None,
            request_format: None,
            response_format: None,
            metadata: None,
            invocation: None,
        }
    }

    /// Registering a trigger stores it locally; unregistering it removes the entry again.
    #[tokio::test]
    async fn register_trigger_unregister_removes_entry() {
        let worker = register_worker(ENGINE_URL, InitOptions::default());
        let input = RegisterTriggerInput {
            trigger_type: "demo".to_string(),
            function_id: "functions.echo".to_string(),
            config: json!({ "foo": "bar" }),
            metadata: None,
        };
        let handle = worker.register_trigger(input).unwrap();

        assert_eq!(worker.inner.triggers.lock().unwrap().len(), 1);

        handle.unregister();

        assert_eq!(worker.inner.triggers.lock().unwrap().len(), 0);
    }

    /// An HTTP-invoked function is stored under its id and removed again on unregister.
    #[tokio::test]
    async fn register_function_with_http_config_stores_and_unregister_removes() {
        let worker = register_worker(ENGINE_URL, InitOptions::default());
        let http = HttpInvocationConfig {
            url: "https://example.com/invoke".to_string(),
            method: HttpMethod::Post,
            timeout_ms: Some(30000),
            headers: HashMap::new(),
            auth: None,
        };

        let registered =
            worker.register_function_with(bare_function_message("external::my_lambda"), http);

        assert_eq!(registered.id, "external::my_lambda");
        assert_eq!(worker.inner.functions.lock().unwrap().len(), 1);

        registered.unregister();

        assert_eq!(worker.inner.functions.lock().unwrap().len(), 0);
    }

    /// Registration with an empty id is rejected with a panic.
    #[tokio::test]
    #[should_panic(expected = "id is required")]
    async fn register_function_rejects_empty_id() {
        let worker = register_worker(ENGINE_URL, InitOptions::default());
        let http = HttpInvocationConfig {
            url: "https://example.com/invoke".to_string(),
            method: HttpMethod::Post,
            timeout_ms: None,
            headers: HashMap::new(),
            auth: None,
        };

        worker.register_function_with(bare_function_message(""), http);
    }

    /// A call that receives no reply fails with `Timeout` and leaves no pending entry behind.
    #[tokio::test]
    async fn invoke_function_times_out_and_clears_pending() {
        let worker = register_worker(ENGINE_URL, InitOptions::default());
        let request = TriggerRequest {
            function_id: "functions.echo".to_string(),
            payload: json!({ "a": 1 }),
            action: None,
            timeout_ms: Some(10),
        };

        let outcome = worker.trigger(request).await;

        assert!(matches!(outcome, Err(IIIError::Timeout)));
        assert!(worker.inner.pending.lock().unwrap().is_empty());
    }
}