Skip to main content

iii_sdk/
iii.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::{
4        Arc, Mutex, MutexGuard,
5        atomic::{AtomicBool, AtomicUsize, Ordering},
6    },
7    time::Duration,
8};
9
/// Poison-tolerant locking for [`Mutex`].
///
/// A poisoned mutex makes `lock()` return an error; this extension hands back the
/// guard anyway. Only use it when the protected data is still valid after a panic
/// in another thread.
trait MutexExt<T> {
    /// Acquires the lock, taking over the guard if the mutex is poisoned.
    fn lock_or_recover(&self) -> MutexGuard<'_, T>;
}

impl<T> MutexExt<T> for Mutex<T> {
    fn lock_or_recover(&self) -> MutexGuard<'_, T> {
        match self.lock() {
            Ok(guard) => guard,
            // The poison error still owns the guard; recover it instead of panicking.
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
21
22use futures_util::{SinkExt, StreamExt};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use tokio::{
26    sync::{mpsc, oneshot},
27    time::sleep,
28};
29use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
30use uuid::Uuid;
31
/// Version of this crate, captured from `CARGO_PKG_VERSION` at compile time.
/// Reported as the `version` field of the default [`WorkerMetadata`].
const SDK_VERSION: &str = env!("CARGO_PKG_VERSION");
33
34use crate::{
35    context::{Context, with_context},
36    error::IIIError,
37    logger::{Logger, LoggerInvoker},
38    protocol::{
39        ErrorBody, Message, RegisterFunctionMessage, RegisterServiceMessage,
40        RegisterTriggerMessage, RegisterTriggerTypeMessage, UnregisterTriggerMessage,
41    },
42    triggers::{Trigger, TriggerConfig, TriggerHandler},
43    types::{RemoteFunctionData, RemoteFunctionHandler, RemoteTriggerTypeData},
44};
45
46#[cfg(feature = "otel")]
47use crate::telemetry;
48#[cfg(feature = "otel")]
49use crate::telemetry::types::OtelConfig;
50
/// Timeout applied by `III::call`, which forwards to `call_with_timeout` with this value.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
52
/// Worker information returned by `engine.workers.list`
///
/// `III::list_workers` deserializes the `workers` array of that response into this
/// type. `Option` fields may be missing or `null` in the payload; all other fields
/// must be present for an entry to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    pub id: String,
    pub name: Option<String>,
    pub runtime: Option<String>,
    pub version: Option<String>,
    pub os: Option<String>,
    pub ip_address: Option<String>,
    pub status: String,
    // NOTE(review): presumably a Unix timestamp in milliseconds — confirm against the engine.
    pub connected_at_ms: u64,
    pub function_count: usize,
    // NOTE(review): presumably the ids of the functions this worker registered — confirm.
    pub functions: Vec<String>,
    pub active_invocations: usize,
}
68
/// Function information returned by `engine.functions.list`
///
/// Also used for the `functions` array delivered to callbacks registered through
/// `III::on_functions_available`. Only `function_id` is required; the remaining
/// fields deserialize to `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionInfo {
    pub function_id: String,
    pub description: Option<String>,
    // The format and metadata fields are kept as raw JSON; their schema is not defined here.
    pub request_format: Option<Value>,
    pub response_format: Option<Value>,
    pub metadata: Option<Value>,
}
78
/// Trigger information returned by `engine.triggers.list`
///
/// `III::list_triggers` deserializes the `triggers` array of that response into
/// this type; every field is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TriggerInfo {
    pub id: String,
    pub trigger_type: String,
    pub function_id: String,
    // Trigger configuration as raw JSON; its shape is not defined in this file.
    pub config: Value,
}
87
/// Worker metadata for auto-registration
///
/// After each successful connection the client sends this value to
/// `engine.workers.register`. The `Default` implementation fills the fields from
/// the running process (SDK version, hostname, pid, OS constants).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerMetadata {
    pub runtime: String,
    pub version: String,
    pub name: String,
    pub os: String,
}
96
97impl Default for WorkerMetadata {
98    fn default() -> Self {
99        let hostname = hostname::get()
100            .map(|h| h.to_string_lossy().to_string())
101            .unwrap_or_else(|_| "unknown".to_string());
102        let pid = std::process::id();
103        let os_info = format!(
104            "{} {} ({})",
105            std::env::consts::OS,
106            std::env::consts::ARCH,
107            std::env::consts::FAMILY
108        );
109
110        Self {
111            runtime: "rust".to_string(),
112            version: SDK_VERSION.to_string(),
113            name: format!("{}:{}", hostname, pid),
114            os: os_info,
115        }
116    }
117}
118
/// Item carried on the internal channel consumed by the connection loop.
enum Outbound {
    /// A protocol message to write to the WebSocket.
    Message(Message),
    /// Tells the connection loop to clear `running` and exit.
    Shutdown,
}
123
/// Sender half used to deliver the result of an in-flight invocation back to the
/// task awaiting it in `call_with_timeout`.
type PendingInvocation = oneshot::Sender<Result<Value, IIIError>>;
125
// WebSocket transmitter type alias: the sink half produced by `stream.split()`
// on the stream returned from `connect_async`.
type WsTx = futures_util::stream::SplitSink<
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
    WsMessage,
>;
131
/// Produce the trace-context headers attached to outbound invocations.
///
/// Returns `(traceparent, baggage)` as reported by the telemetry context module.
/// This variant is compiled only when the `otel` feature is enabled.
#[cfg(feature = "otel")]
fn inject_trace_headers() -> (Option<String>, Option<String>) {
    let traceparent = crate::telemetry::context::inject_traceparent();
    let baggage = crate::telemetry::context::inject_baggage();
    (traceparent, baggage)
}
139
/// Fallback used when the `otel` feature is disabled: there is no trace context
/// to propagate, so both `traceparent` and `baggage` are `None`.
#[cfg(not(feature = "otel"))]
fn inject_trace_headers() -> (Option<String>, Option<String>) {
    let traceparent: Option<String> = None;
    let baggage: Option<String> = None;
    (traceparent, baggage)
}
144
/// Callback function type for functions available events.
///
/// Each invocation receives its own `Vec<FunctionInfo>` (the list is cloned per
/// callback). Stored behind an `Arc` so the same callback can be shared across threads.
pub type FunctionsAvailableCallback = Arc<dyn Fn(Vec<FunctionInfo>) + Send + Sync>;
147
/// Shared state behind every clone of [`III`].
struct IIIInner {
    /// WebSocket address handed to `connect_async`.
    address: String,
    /// Sending half of the queue drained by the connection loop.
    outbound: mpsc::UnboundedSender<Outbound>,
    /// Receiving half of that queue; taken exactly once by `connect`.
    receiver: Mutex<Option<mpsc::UnboundedReceiver<Outbound>>>,
    /// While `false`, `send_message` drops messages and the connection loop exits.
    running: AtomicBool,
    /// Set by the first `connect` call so later calls become no-ops.
    started: AtomicBool,
    /// In-flight invocations awaiting an `InvocationResult`, keyed by invocation id.
    pending: Mutex<HashMap<Uuid, PendingInvocation>>,
    // The four registries below are re-sent to the engine on every (re)connect
    // by `collect_registrations`.
    functions: Mutex<HashMap<String, RemoteFunctionData>>,
    trigger_types: Mutex<HashMap<String, RemoteTriggerTypeData>>,
    triggers: Mutex<HashMap<String, RegisterTriggerMessage>>,
    services: Mutex<HashMap<String, RegisterServiceMessage>>,
    /// Metadata sent to `engine.workers.register` after connecting, if any.
    worker_metadata: Mutex<Option<WorkerMetadata>>,
    /// Subscribers registered through `on_functions_available`, keyed by subscription id.
    functions_available_callbacks: Mutex<HashMap<usize, FunctionsAvailableCallback>>,
    /// Source of subscription ids for the map above.
    functions_available_callback_counter: AtomicUsize,
    /// Id of the internal handler function that fans events out to subscribers;
    /// created lazily and reused afterwards.
    functions_available_function_id: Mutex<Option<String>>,
    /// The registered functions-available trigger, while at least one subscriber exists.
    functions_available_trigger: Mutex<Option<Trigger>>,
    /// OpenTelemetry configuration, consumed by `connect`.
    #[cfg(feature = "otel")]
    otel_config: Mutex<Option<OtelConfig>>,
}
167
/// Client handle for the III engine.
///
/// Cloning is cheap: every clone shares the same `Arc`-held state, including the
/// outbound queue and all registries.
#[derive(Clone)]
pub struct III {
    inner: Arc<IIIInner>,
}
172
/// Guard that unsubscribes from functions available events when dropped
///
/// Returned by `III::on_functions_available`; keep it alive for as long as the
/// callback should keep receiving events.
pub struct FunctionsAvailableGuard {
    // Client whose subscription table holds the callback.
    iii: III,
    // Key of the callback in `functions_available_callbacks`.
    callback_id: usize,
}
178
179impl Drop for FunctionsAvailableGuard {
180    fn drop(&mut self) {
181        let mut callbacks = self
182            .iii
183            .inner
184            .functions_available_callbacks
185            .lock_or_recover();
186        callbacks.remove(&self.callback_id);
187
188        if callbacks.is_empty() {
189            let mut trigger = self.iii.inner.functions_available_trigger.lock_or_recover();
190            if let Some(trigger) = trigger.take() {
191                trigger.unregister();
192            }
193        }
194    }
195}
196
197impl III {
198    /// Create a new III with default worker metadata (auto-detected runtime, os, hostname)
199    pub fn new(address: &str) -> Self {
200        Self::with_metadata(address, WorkerMetadata::default())
201    }
202
203    /// Create a new III with custom worker metadata
204    pub fn with_metadata(address: &str, metadata: WorkerMetadata) -> Self {
205        let (tx, rx) = mpsc::unbounded_channel();
206        let inner = IIIInner {
207            address: address.into(),
208            outbound: tx,
209            receiver: Mutex::new(Some(rx)),
210            running: AtomicBool::new(false),
211            started: AtomicBool::new(false),
212            pending: Mutex::new(HashMap::new()),
213            functions: Mutex::new(HashMap::new()),
214            trigger_types: Mutex::new(HashMap::new()),
215            triggers: Mutex::new(HashMap::new()),
216            services: Mutex::new(HashMap::new()),
217            worker_metadata: Mutex::new(Some(metadata)),
218            functions_available_callbacks: Mutex::new(HashMap::new()),
219            functions_available_callback_counter: AtomicUsize::new(0),
220            functions_available_function_id: Mutex::new(None),
221            functions_available_trigger: Mutex::new(None),
222            #[cfg(feature = "otel")]
223            otel_config: Mutex::new(None),
224        };
225        Self {
226            inner: Arc::new(inner),
227        }
228    }
229
230    /// Set custom worker metadata (call before connect)
231    pub fn set_metadata(&self, metadata: WorkerMetadata) {
232        *self.inner.worker_metadata.lock_or_recover() = Some(metadata);
233    }
234
235    /// Set OpenTelemetry configuration (call before connect)
236    #[cfg(feature = "otel")]
237    pub fn set_otel_config(&self, config: OtelConfig) {
238        *self.inner.otel_config.lock_or_recover() = Some(config);
239    }
240
241    pub async fn connect(&self) -> Result<(), IIIError> {
242        if self.inner.started.swap(true, Ordering::SeqCst) {
243            return Ok(());
244        }
245
246        let receiver = self.inner.receiver.lock_or_recover().take();
247        let Some(rx) = receiver else {
248            return Ok(());
249        };
250
251        let iii = self.clone();
252
253        tokio::spawn(async move {
254            iii.inner.running.store(true, Ordering::SeqCst);
255            iii.run_connection(rx).await;
256        });
257
258        // Initialize OpenTelemetry if configured.
259        // NOTE: This runs after the connection spawn, so the first few function
260        // invocations may not carry tracing context. The global tracer returns a
261        // no-op until initialization completes, so no panics occur — traces
262        // simply won't appear for those early calls.
263        #[cfg(feature = "otel")]
264        {
265            let config = self.inner.otel_config.lock_or_recover().take();
266            if let Some(mut config) = config {
267                // Default engine_ws_url to the III address if not set
268                if config.engine_ws_url.is_none() {
269                    config.engine_ws_url = Some(self.inner.address.clone());
270                }
271                telemetry::init_otel(config).await;
272            }
273        }
274
275        Ok(())
276    }
277
278    /// Shutdown the III client.
279    ///
280    /// This stops the connection loop and sends a shutdown signal.
281    /// If the `otel` feature is enabled, this will spawn a background task
282    /// to flush telemetry data, but does NOT wait for it to complete.
283    /// For guaranteed telemetry flush, use `shutdown_async()` instead.
284    #[deprecated(note = "Use shutdown_async() for guaranteed telemetry flush")]
285    pub fn shutdown(&self) {
286        self.inner.running.store(false, Ordering::SeqCst);
287        let _ = self.inner.outbound.send(Outbound::Shutdown);
288
289        // Shutdown OpenTelemetry (best-effort, does not wait for flush)
290        #[cfg(feature = "otel")]
291        {
292            tracing::warn!(
293                "shutdown() does not await telemetry flush; use shutdown_async() instead"
294            );
295            tokio::spawn(async {
296                telemetry::shutdown_otel().await;
297            });
298        }
299    }
300
301    /// Shutdown the III client and flush all pending telemetry data.
302    ///
303    /// This method stops the connection loop and sends a shutdown signal.
304    /// When the `otel` feature is enabled, it additionally awaits the
305    /// OpenTelemetry flush, ensuring all spans, metrics, and logs are
306    /// exported before returning.
307    pub async fn shutdown_async(&self) {
308        self.inner.running.store(false, Ordering::SeqCst);
309        let _ = self.inner.outbound.send(Outbound::Shutdown);
310
311        #[cfg(feature = "otel")]
312        telemetry::shutdown_otel().await;
313    }
314
315    pub fn register_function<F, Fut>(&self, id: impl Into<String>, handler: F)
316    where
317        F: Fn(Value) -> Fut + Send + Sync + 'static,
318        Fut: std::future::Future<Output = Result<Value, IIIError>> + Send + 'static,
319    {
320        let message = RegisterFunctionMessage {
321            id: id.into(),
322            description: None,
323            request_format: None,
324            response_format: None,
325            metadata: None,
326        };
327
328        self.register_function_with(message, handler);
329    }
330
331    pub fn register_function_with_description<F, Fut>(
332        &self,
333        id: impl Into<String>,
334        description: impl Into<String>,
335        handler: F,
336    ) where
337        F: Fn(Value) -> Fut + Send + Sync + 'static,
338        Fut: std::future::Future<Output = Result<Value, IIIError>> + Send + 'static,
339    {
340        let message = RegisterFunctionMessage {
341            id: id.into(),
342            description: Some(description.into()),
343            request_format: None,
344            response_format: None,
345            metadata: None,
346        };
347
348        self.register_function_with(message, handler);
349    }
350
351    pub fn register_function_with<F, Fut>(&self, message: RegisterFunctionMessage, handler: F)
352    where
353        F: Fn(Value) -> Fut + Send + Sync + 'static,
354        Fut: std::future::Future<Output = Result<Value, IIIError>> + Send + 'static,
355    {
356        let function_id = message.id.clone();
357        let iii = self.clone();
358
359        let user_handler = Arc::new(move |input: Value| Box::pin(handler(input)));
360
361        let wrapped_handler: RemoteFunctionHandler = Arc::new(move |input: Value| {
362            let function_id = function_id.clone();
363            let iii = iii.clone();
364            let user_handler = user_handler.clone();
365
366            Box::pin(async move {
367                let invoker: LoggerInvoker = Arc::new(move |path, params| {
368                    let _ = iii.call_void(path, params);
369                });
370
371                let logger = Logger::new(
372                    Some(invoker),
373                    Some(Uuid::new_v4().to_string()),
374                    Some(function_id.clone()),
375                );
376                let context = Context { logger };
377
378                with_context(context, || user_handler(input)).await
379            })
380        });
381
382        let data = RemoteFunctionData {
383            message: message.clone(),
384            handler: wrapped_handler,
385        };
386
387        self.inner
388            .functions
389            .lock_or_recover()
390            .insert(message.id.clone(), data);
391        let _ = self.send_message(message.to_message());
392    }
393
394    pub fn register_service(&self, id: impl Into<String>, description: Option<String>) {
395        let id = id.into();
396        let message = RegisterServiceMessage {
397            id: id.clone(),
398            name: id,
399            description,
400        };
401
402        self.inner
403            .services
404            .lock_or_recover()
405            .insert(message.id.clone(), message.clone());
406        let _ = self.send_message(message.to_message());
407    }
408
409    pub fn register_service_with_name(
410        &self,
411        id: impl Into<String>,
412        name: impl Into<String>,
413        description: Option<String>,
414    ) {
415        let message = RegisterServiceMessage {
416            id: id.into(),
417            name: name.into(),
418            description,
419        };
420
421        self.inner
422            .services
423            .lock_or_recover()
424            .insert(message.id.clone(), message.clone());
425        let _ = self.send_message(message.to_message());
426    }
427
428    pub fn register_trigger_type<H>(
429        &self,
430        id: impl Into<String>,
431        description: impl Into<String>,
432        handler: H,
433    ) where
434        H: TriggerHandler + 'static,
435    {
436        let message = RegisterTriggerTypeMessage {
437            id: id.into(),
438            description: description.into(),
439        };
440
441        self.inner.trigger_types.lock_or_recover().insert(
442            message.id.clone(),
443            RemoteTriggerTypeData {
444                message: message.clone(),
445                handler: Arc::new(handler),
446            },
447        );
448
449        let _ = self.send_message(message.to_message());
450    }
451
452    pub fn unregister_trigger_type(&self, id: impl Into<String>) {
453        let id = id.into();
454        self.inner.trigger_types.lock_or_recover().remove(&id);
455    }
456
457    pub fn register_trigger(
458        &self,
459        trigger_type: impl Into<String>,
460        function_id: impl Into<String>,
461        config: impl serde::Serialize,
462    ) -> Result<Trigger, IIIError> {
463        let id = Uuid::new_v4().to_string();
464        let config = serde_json::to_value(config)?;
465        let message = RegisterTriggerMessage {
466            id: id.clone(),
467            trigger_type: trigger_type.into(),
468            function_id: function_id.into(),
469            config,
470        };
471
472        self.inner
473            .triggers
474            .lock_or_recover()
475            .insert(message.id.clone(), message.clone());
476        let _ = self.send_message(message.to_message());
477
478        let iii = self.clone();
479        let trigger_type = message.trigger_type.clone();
480        let unregister_id = message.id.clone();
481        let unregister_fn = Arc::new(move || {
482            let _ = iii.inner.triggers.lock_or_recover().remove(&unregister_id);
483            let msg = UnregisterTriggerMessage {
484                id: unregister_id.clone(),
485                trigger_type: trigger_type.clone(),
486            };
487            let _ = iii.send_message(msg.to_message());
488        });
489
490        Ok(Trigger::new(unregister_fn))
491    }
492
493    pub async fn call(
494        &self,
495        function_id: &str,
496        data: impl serde::Serialize,
497    ) -> Result<Value, IIIError> {
498        let value = serde_json::to_value(data)?;
499        self.call_with_timeout(function_id, value, DEFAULT_TIMEOUT)
500            .await
501    }
502
503    pub async fn call_with_timeout(
504        &self,
505        function_id: &str,
506        data: Value,
507        timeout: Duration,
508    ) -> Result<Value, IIIError> {
509        let invocation_id = Uuid::new_v4();
510        let (tx, rx) = oneshot::channel();
511
512        self.inner
513            .pending
514            .lock_or_recover()
515            .insert(invocation_id, tx);
516
517        let (tp, bg) = inject_trace_headers();
518
519        self.send_message(Message::InvokeFunction {
520            invocation_id: Some(invocation_id),
521            function_id: function_id.to_string(),
522            data,
523            traceparent: tp,
524            baggage: bg,
525        })?;
526
527        match tokio::time::timeout(timeout, rx).await {
528            Ok(Ok(result)) => result,
529            Ok(Err(_)) => Err(IIIError::NotConnected),
530            Err(_) => {
531                self.inner.pending.lock_or_recover().remove(&invocation_id);
532                Err(IIIError::Timeout)
533            }
534        }
535    }
536
537    pub fn call_void<TInput>(&self, function_id: &str, data: TInput) -> Result<(), IIIError>
538    where
539        TInput: Serialize,
540    {
541        let value = serde_json::to_value(data)?;
542
543        let (tp, bg) = inject_trace_headers();
544
545        self.send_message(Message::InvokeFunction {
546            invocation_id: None,
547            function_id: function_id.to_string(),
548            data: value,
549            traceparent: tp,
550            baggage: bg,
551        })
552    }
553
554    /// List all registered functions from the engine
555    pub async fn list_functions(&self) -> Result<Vec<FunctionInfo>, IIIError> {
556        let result = self
557            .call("engine.functions.list", serde_json::json!({}))
558            .await?;
559
560        let functions = result
561            .get("functions")
562            .and_then(|v| serde_json::from_value::<Vec<FunctionInfo>>(v.clone()).ok())
563            .unwrap_or_default();
564
565        Ok(functions)
566    }
567
568    /// Subscribe to function availability events
569    /// Returns a guard that will unsubscribe when dropped
570    pub fn on_functions_available<F>(&self, callback: F) -> FunctionsAvailableGuard
571    where
572        F: Fn(Vec<FunctionInfo>) + Send + Sync + 'static,
573    {
574        let callback = Arc::new(callback);
575        let callback_id = self
576            .inner
577            .functions_available_callback_counter
578            .fetch_add(1, Ordering::Relaxed);
579
580        self.inner
581            .functions_available_callbacks
582            .lock_or_recover()
583            .insert(callback_id, callback);
584
585        // Set up trigger if not already done
586        let mut trigger_guard = self.inner.functions_available_trigger.lock_or_recover();
587        if trigger_guard.is_none() {
588            // Get or create function path (reuse existing if trigger registration previously failed)
589            let function_id = {
590                let mut path_guard = self.inner.functions_available_function_id.lock_or_recover();
591                if path_guard.is_none() {
592                    let path = format!("iii.on_functions_available.{}", Uuid::new_v4());
593                    *path_guard = Some(path.clone());
594                    path
595                } else {
596                    path_guard.clone().unwrap()
597                }
598            };
599
600            // Register handler function only if it doesn't already exist
601            let function_exists = self
602                .inner
603                .functions
604                .lock_or_recover()
605                .contains_key(&function_id);
606            if !function_exists {
607                let iii = self.clone();
608                self.register_function(function_id.clone(), move |input: Value| {
609                    let iii = iii.clone();
610                    async move {
611                        // Extract functions from trigger payload
612                        let functions = input
613                            .get("functions")
614                            .and_then(|v| {
615                                serde_json::from_value::<Vec<FunctionInfo>>(v.clone()).ok()
616                            })
617                            .unwrap_or_default();
618
619                        let callbacks = iii.inner.functions_available_callbacks.lock_or_recover();
620                        for cb in callbacks.values() {
621                            cb(functions.clone());
622                        }
623                        Ok(Value::Null)
624                    }
625                });
626            }
627
628            // Register trigger
629            match self.register_trigger(
630                "engine::functions-available",
631                function_id,
632                serde_json::json!({}),
633            ) {
634                Ok(trigger) => {
635                    *trigger_guard = Some(trigger);
636                }
637                Err(err) => {
638                    tracing::warn!(error = %err, "Failed to register functions_available trigger");
639                }
640            }
641        }
642
643        FunctionsAvailableGuard {
644            iii: self.clone(),
645            callback_id,
646        }
647    }
648
649    /// List all connected workers from the engine
650    pub async fn list_workers(&self) -> Result<Vec<WorkerInfo>, IIIError> {
651        let result = self
652            .call("engine.workers.list", serde_json::json!({}))
653            .await?;
654
655        let workers = result
656            .get("workers")
657            .and_then(|v| serde_json::from_value::<Vec<WorkerInfo>>(v.clone()).ok())
658            .unwrap_or_default();
659
660        Ok(workers)
661    }
662
663    /// List all registered triggers from the engine
664    pub async fn list_triggers(&self) -> Result<Vec<TriggerInfo>, IIIError> {
665        let result = self
666            .call("engine.triggers.list", serde_json::json!({}))
667            .await?;
668
669        let triggers = result
670            .get("triggers")
671            .and_then(|v| serde_json::from_value::<Vec<TriggerInfo>>(v.clone()).ok())
672            .unwrap_or_default();
673
674        Ok(triggers)
675    }
676
677    /// Register this worker's metadata with the engine (called automatically on connect)
678    fn register_worker_metadata(&self) {
679        if let Some(metadata) = self.inner.worker_metadata.lock_or_recover().clone() {
680            let _ = self.call_void("engine.workers.register", metadata);
681        }
682    }
683
684    fn send_message(&self, message: Message) -> Result<(), IIIError> {
685        if !self.inner.running.load(Ordering::SeqCst) {
686            return Ok(());
687        }
688
689        self.inner
690            .outbound
691            .send(Outbound::Message(message))
692            .map_err(|_| IIIError::NotConnected)
693    }
694
695    async fn run_connection(&self, mut rx: mpsc::UnboundedReceiver<Outbound>) {
696        let mut queue: Vec<Message> = Vec::new();
697
698        while self.inner.running.load(Ordering::SeqCst) {
699            match connect_async(&self.inner.address).await {
700                Ok((stream, _)) => {
701                    tracing::info!(address = %self.inner.address, "iii connected");
702                    let (mut ws_tx, mut ws_rx) = stream.split();
703
704                    queue.extend(self.collect_registrations());
705                    Self::dedupe_registrations(&mut queue);
706                    if let Err(err) = self.flush_queue(&mut ws_tx, &mut queue).await {
707                        tracing::warn!(error = %err, "failed to flush queue");
708                        sleep(Duration::from_secs(2)).await;
709                        continue;
710                    }
711
712                    // Auto-register worker metadata on connect (like Node SDK)
713                    self.register_worker_metadata();
714
715                    let mut should_reconnect = false;
716
717                    while self.inner.running.load(Ordering::SeqCst) && !should_reconnect {
718                        tokio::select! {
719                            outgoing = rx.recv() => {
720                                match outgoing {
721                                    Some(Outbound::Message(message)) => {
722                                        if let Err(err) = self.send_ws(&mut ws_tx, &message).await {
723                                            tracing::warn!(error = %err, "send failed; reconnecting");
724                                            queue.push(message);
725                                            should_reconnect = true;
726                                        }
727                                    }
728                                    Some(Outbound::Shutdown) => {
729                                        self.inner.running.store(false, Ordering::SeqCst);
730                                        return;
731                                    }
732                                    None => {
733                                        self.inner.running.store(false, Ordering::SeqCst);
734                                        return;
735                                    }
736                                }
737                            }
738                            incoming = ws_rx.next() => {
739                                match incoming {
740                                    Some(Ok(frame)) => {
741                                        if let Err(err) = self.handle_frame(frame) {
742                                            tracing::warn!(error = %err, "failed to handle frame");
743                                        }
744                                    }
745                                    Some(Err(err)) => {
746                                        tracing::warn!(error = %err, "websocket receive error");
747                                        should_reconnect = true;
748                                    }
749                                    None => {
750                                        should_reconnect = true;
751                                    }
752                                }
753                            }
754                        }
755                    }
756                }
757                Err(err) => {
758                    tracing::warn!(error = %err, "failed to connect; retrying");
759                }
760            }
761
762            if self.inner.running.load(Ordering::SeqCst) {
763                sleep(Duration::from_secs(2)).await;
764            }
765        }
766    }
767
768    fn collect_registrations(&self) -> Vec<Message> {
769        let mut messages = Vec::new();
770
771        for trigger_type in self.inner.trigger_types.lock_or_recover().values() {
772            messages.push(trigger_type.message.to_message());
773        }
774
775        for service in self.inner.services.lock_or_recover().values() {
776            messages.push(service.to_message());
777        }
778
779        for function in self.inner.functions.lock_or_recover().values() {
780            messages.push(function.message.to_message());
781        }
782
783        for trigger in self.inner.triggers.lock_or_recover().values() {
784            messages.push(trigger.to_message());
785        }
786
787        messages
788    }
789
790    fn dedupe_registrations(queue: &mut Vec<Message>) {
791        let mut seen = HashSet::new();
792        let mut deduped_rev = Vec::with_capacity(queue.len());
793
794        for message in queue.iter().rev() {
795            let key = match message {
796                Message::RegisterTriggerType { id, .. } => format!("trigger_type:{id}"),
797                Message::RegisterTrigger { id, .. } => format!("trigger:{id}"),
798                Message::RegisterFunction { id, .. } => {
799                    format!("function:{id}")
800                }
801                Message::RegisterService { id, .. } => format!("service:{id}"),
802                _ => {
803                    deduped_rev.push(message.clone());
804                    continue;
805                }
806            };
807
808            if seen.insert(key) {
809                deduped_rev.push(message.clone());
810            }
811        }
812
813        deduped_rev.reverse();
814        *queue = deduped_rev;
815    }
816
817    async fn flush_queue(
818        &self,
819        ws_tx: &mut WsTx,
820        queue: &mut Vec<Message>,
821    ) -> Result<(), IIIError> {
822        let mut drained = Vec::new();
823        std::mem::swap(queue, &mut drained);
824
825        let mut iter = drained.into_iter();
826        while let Some(message) = iter.next() {
827            if let Err(err) = self.send_ws(ws_tx, &message).await {
828                queue.push(message);
829                queue.extend(iter);
830                return Err(err);
831            }
832        }
833
834        Ok(())
835    }
836
837    async fn send_ws(&self, ws_tx: &mut WsTx, message: &Message) -> Result<(), IIIError> {
838        let payload = serde_json::to_string(message)?;
839        ws_tx.send(WsMessage::Text(payload.into())).await?;
840        Ok(())
841    }
842
843    fn handle_frame(&self, frame: WsMessage) -> Result<(), IIIError> {
844        match frame {
845            WsMessage::Text(text) => self.handle_message(&text),
846            WsMessage::Binary(bytes) => {
847                let text = String::from_utf8_lossy(&bytes).to_string();
848                self.handle_message(&text)
849            }
850            _ => Ok(()),
851        }
852    }
853
854    fn handle_message(&self, payload: &str) -> Result<(), IIIError> {
855        let message: Message = serde_json::from_str(payload)?;
856
857        match message {
858            Message::InvocationResult {
859                invocation_id,
860                result,
861                error,
862                ..
863            } => {
864                self.handle_invocation_result(invocation_id, result, error);
865            }
866            Message::InvokeFunction {
867                invocation_id,
868                function_id,
869                data,
870                traceparent,
871                baggage,
872            } => {
873                self.handle_invoke_function(invocation_id, function_id, data, traceparent, baggage);
874            }
875            Message::RegisterTrigger {
876                id,
877                trigger_type,
878                function_id,
879                config,
880            } => {
881                self.handle_register_trigger(id, trigger_type, function_id, config);
882            }
883            Message::Ping => {
884                let _ = self.send_message(Message::Pong);
885            }
886            Message::WorkerRegistered { worker_id } => {
887                tracing::debug!(worker_id = %worker_id, "Worker registered");
888            }
889            _ => {}
890        }
891
892        Ok(())
893    }
894
895    fn handle_invocation_result(
896        &self,
897        invocation_id: Uuid,
898        result: Option<Value>,
899        error: Option<ErrorBody>,
900    ) {
901        let sender = self.inner.pending.lock_or_recover().remove(&invocation_id);
902        if let Some(sender) = sender {
903            let result = match error {
904                Some(error) => Err(IIIError::Remote {
905                    code: error.code,
906                    message: error.message,
907                }),
908                None => Ok(result.unwrap_or(Value::Null)),
909            };
910            let _ = sender.send(result);
911        }
912    }
913
914    fn handle_invoke_function(
915        &self,
916        invocation_id: Option<Uuid>,
917        function_id: String,
918        data: Value,
919        traceparent: Option<String>,
920        baggage: Option<String>,
921    ) {
922        tracing::debug!(function_id = %function_id, traceparent = ?traceparent, baggage = ?baggage, "Invoking function");
923
924        let handler = self
925            .inner
926            .functions
927            .lock_or_recover()
928            .get(&function_id)
929            .map(|data| data.handler.clone());
930
931        let Some(handler) = handler else {
932            tracing::warn!(function_id = %function_id, "Invocation: Function not found");
933
934            if let Some(invocation_id) = invocation_id {
935                let (resp_tp, resp_bg) = inject_trace_headers();
936
937                let error = ErrorBody {
938                    code: "function_not_found".to_string(),
939                    message: "Function not found".to_string(),
940                };
941                let result = self.send_message(Message::InvocationResult {
942                    invocation_id,
943                    function_id,
944                    result: None,
945                    error: Some(error),
946                    traceparent: resp_tp,
947                    baggage: resp_bg,
948                });
949
950                if let Err(err) = result {
951                    tracing::warn!(error = %err, "error sending invocation result");
952                }
953            }
954            return;
955        };
956
957        let iii = self.clone();
958
959        tokio::spawn(async move {
960            // Extract incoming trace context and create a span for this invocation.
961            // This ensures the handler and any outbound calls it makes (e.g.
962            // invoke_function_with_timeout) are linked as children of the caller's trace.
963            // We use FutureExt::with_context() instead of cx.attach() because
964            // ContextGuard is !Send and can't be held across .await in tokio::spawn.
965            #[cfg(feature = "otel")]
966            let otel_cx = {
967                use crate::telemetry::context::extract_context;
968                use opentelemetry::trace::{SpanKind, TraceContextExt, Tracer};
969
970                let parent_cx = extract_context(traceparent.as_deref(), baggage.as_deref());
971                let tracer = opentelemetry::global::tracer("iii-rust-sdk");
972                let span = tracer
973                    .span_builder(format!("invoke {}", function_id))
974                    .with_kind(SpanKind::Server)
975                    .start_with_context(&tracer, &parent_cx);
976                parent_cx.with_span(span)
977            };
978
979            #[cfg(feature = "otel")]
980            let result = {
981                use opentelemetry::trace::FutureExt as OtelFutureExt;
982                handler(data).with_context(otel_cx.clone()).await
983            };
984
985            #[cfg(not(feature = "otel"))]
986            let result = handler(data).await;
987
988            // Record span status based on result
989            #[cfg(feature = "otel")]
990            {
991                use opentelemetry::trace::{Status, TraceContextExt};
992                let span = otel_cx.span();
993                match &result {
994                    Ok(_) => span.set_status(Status::Ok),
995                    Err(err) => span.set_status(Status::error(err.to_string())),
996                }
997            }
998
999            if let Some(invocation_id) = invocation_id {
1000                // Inject trace context from our span into the response.
1001                // We briefly attach the otel context (no .await crossing)
1002                // so inject_traceparent/inject_baggage can read it.
1003                #[cfg(feature = "otel")]
1004                let (resp_tp, resp_bg) = {
1005                    let _guard = otel_cx.attach();
1006                    inject_trace_headers()
1007                };
1008                #[cfg(not(feature = "otel"))]
1009                let (resp_tp, resp_bg) = inject_trace_headers();
1010
1011                let message = match result {
1012                    Ok(value) => Message::InvocationResult {
1013                        invocation_id,
1014                        function_id,
1015                        result: Some(value),
1016                        error: None,
1017                        traceparent: resp_tp,
1018                        baggage: resp_bg,
1019                    },
1020                    Err(err) => Message::InvocationResult {
1021                        invocation_id,
1022                        function_id,
1023                        result: None,
1024                        error: Some(ErrorBody {
1025                            code: "invocation_failed".to_string(),
1026                            message: err.to_string(),
1027                        }),
1028                        traceparent: resp_tp,
1029                        baggage: resp_bg,
1030                    },
1031                };
1032
1033                let _ = iii.send_message(message);
1034            } else if let Err(err) = result {
1035                tracing::warn!(error = %err, "error handling async invocation");
1036            }
1037        });
1038    }
1039
1040    fn handle_register_trigger(
1041        &self,
1042        id: String,
1043        trigger_type: String,
1044        function_id: String,
1045        config: Value,
1046    ) {
1047        let handler = self
1048            .inner
1049            .trigger_types
1050            .lock_or_recover()
1051            .get(&trigger_type)
1052            .map(|data| data.handler.clone());
1053
1054        let iii = self.clone();
1055
1056        tokio::spawn(async move {
1057            let message = if let Some(handler) = handler {
1058                let config = TriggerConfig {
1059                    id: id.clone(),
1060                    function_id: function_id.clone(),
1061                    config,
1062                };
1063
1064                match handler.register_trigger(config).await {
1065                    Ok(()) => Message::TriggerRegistrationResult {
1066                        id,
1067                        trigger_type,
1068                        function_id,
1069                        error: None,
1070                    },
1071                    Err(err) => Message::TriggerRegistrationResult {
1072                        id,
1073                        trigger_type,
1074                        function_id,
1075                        error: Some(ErrorBody {
1076                            code: "trigger_registration_failed".to_string(),
1077                            message: err.to_string(),
1078                        }),
1079                    },
1080                }
1081            } else {
1082                Message::TriggerRegistrationResult {
1083                    id,
1084                    trigger_type,
1085                    function_id,
1086                    error: Some(ErrorBody {
1087                        code: "trigger_type_not_found".to_string(),
1088                        message: "Trigger type not found".to_string(),
1089                    }),
1090                }
1091            };
1092
1093            let _ = iii.send_message(message);
1094        });
1095    }
1096}
1097
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Registering a trigger records one entry; unregistering it removes that entry.
    #[test]
    fn register_trigger_unregister_removes_entry() {
        let client = III::new("ws://localhost:1234");
        // Number of triggers currently tracked by the client.
        let tracked = || client.inner.triggers.lock().unwrap().len();

        let handle = client
            .register_trigger("demo", "functions.echo", json!({ "foo": "bar" }))
            .unwrap();
        assert_eq!(tracked(), 1);

        handle.unregister();
        assert_eq!(tracked(), 0);
    }

    /// A call that receives no reply within its timeout fails with
    /// `IIIError::Timeout` and leaves no entry behind in the pending table.
    #[tokio::test]
    async fn invoke_function_times_out_and_clears_pending() {
        let client = III::new("ws://localhost:1234");
        let payload = json!({ "a": 1 });
        let timeout = Duration::from_millis(10);

        let outcome = client
            .call_with_timeout("functions.echo", payload, timeout)
            .await;

        assert!(matches!(outcome, Err(IIIError::Timeout)));
        assert!(client.inner.pending.lock().unwrap().is_empty());
    }
}