1use std::{
2 collections::{HashMap, HashSet},
3 sync::{
4 Arc, Mutex, MutexGuard,
5 atomic::{AtomicBool, AtomicUsize, Ordering},
6 },
7 time::Duration,
8};
9
/// Extension for [`Mutex`] that never propagates lock poisoning.
trait MutexExt<T> {
    /// Acquires the lock, returning the guard even if a previous holder panicked
    /// while holding it (the poison flag is ignored rather than turned into a panic).
    fn lock_or_recover(&self) -> MutexGuard<'_, T>;
}

impl<T> MutexExt<T> for Mutex<T> {
    fn lock_or_recover(&self) -> MutexGuard<'_, T> {
        match self.lock() {
            Ok(guard) => guard,
            // A poisoned lock still hands back the guard; take it and carry on.
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
21
22use futures_util::{SinkExt, StreamExt};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use tokio::{
26 sync::{mpsc, oneshot},
27 time::sleep,
28};
29use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
30use uuid::Uuid;
31
/// Version of this crate, captured at compile time from `CARGO_PKG_VERSION`.
///
/// Used as the default [`WorkerMetadata::version`] in `WorkerMetadata::default()`.
const SDK_VERSION: &str = env!("CARGO_PKG_VERSION");
33
34use crate::{
35 context::{Context, with_context},
36 error::IIIError,
37 logger::{Logger, LoggerInvoker},
38 protocol::{
39 ErrorBody, Message, RegisterFunctionMessage, RegisterServiceMessage,
40 RegisterTriggerMessage, RegisterTriggerTypeMessage, UnregisterTriggerMessage,
41 },
42 triggers::{Trigger, TriggerConfig, TriggerHandler},
43 types::{RemoteFunctionData, RemoteFunctionHandler, RemoteTriggerTypeData},
44};
45
46#[cfg(feature = "otel")]
47use crate::telemetry;
48#[cfg(feature = "otel")]
49use crate::telemetry::types::OtelConfig;
50
/// Timeout applied by [`III::call`] (30 seconds); use `call_with_timeout` to override.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
52
/// One entry of the `workers` array returned by the engine's `engine.workers.list`
/// function (see [`III::list_workers`]).
///
/// NOTE(review): field semantics come from the engine's response schema, which is not
/// visible here; the per-field notes below are inferred from names — confirm upstream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    /// Worker identifier (presumably engine-assigned).
    pub id: String,
    /// Worker name, if reported (cf. [`WorkerMetadata::name`]).
    pub name: Option<String>,
    /// Runtime, if reported (cf. [`WorkerMetadata::runtime`]).
    pub runtime: Option<String>,
    /// SDK version, if reported (cf. [`WorkerMetadata::version`]).
    pub version: Option<String>,
    /// OS description, if reported (cf. [`WorkerMetadata::os`]).
    pub os: Option<String>,
    /// Remote address as seen by the engine, if known.
    pub ip_address: Option<String>,
    /// Connection status string (values defined by the engine).
    pub status: String,
    /// Connection time; presumably milliseconds since the Unix epoch — confirm.
    pub connected_at_ms: u64,
    /// Number of functions registered by the worker.
    pub function_count: usize,
    /// Ids of the functions registered by the worker.
    pub functions: Vec<String>,
    /// Number of invocations currently in progress on the worker.
    pub active_invocations: usize,
}
68
/// One entry of the function list returned by `engine.functions.list`
/// ([`III::list_functions`]) and delivered to [`III::on_functions_available`] callbacks.
///
/// The optional fields mirror those of `RegisterFunctionMessage`; presumably they echo
/// what the registering worker supplied — confirm against the engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionInfo {
    /// Function id used to invoke it (e.g. with [`III::call`]).
    pub function_id: String,
    /// Human-readable description, if any.
    pub description: Option<String>,
    /// Request schema/format, if any (opaque JSON).
    pub request_format: Option<Value>,
    /// Response schema/format, if any (opaque JSON).
    pub response_format: Option<Value>,
    /// Arbitrary extra metadata, if any (opaque JSON).
    pub metadata: Option<Value>,
}
78
/// One entry of the `triggers` array returned by `engine.triggers.list`
/// (see [`III::list_triggers`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TriggerInfo {
    /// Trigger id (this SDK generates UUIDv4 strings for its own triggers).
    pub id: String,
    /// Id of the trigger type this trigger belongs to.
    pub trigger_type: String,
    /// Id of the function the trigger invokes.
    pub function_id: String,
    /// Trigger-type-specific configuration (opaque JSON).
    pub config: Value,
}
87
/// Description of this worker, sent to the engine via `engine.workers.register`
/// after each successful connection.
///
/// See `WorkerMetadata::default()` for the values used when none are supplied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerMetadata {
    /// Runtime name (`"rust"` by default).
    pub runtime: String,
    /// SDK version (this crate's version by default).
    pub version: String,
    /// Worker name (`"<hostname>:<pid>"` by default).
    pub name: String,
    /// OS description (`"<os> <arch> (<family>)"` by default).
    pub os: String,
}
96
97impl Default for WorkerMetadata {
98 fn default() -> Self {
99 let hostname = hostname::get()
100 .map(|h| h.to_string_lossy().to_string())
101 .unwrap_or_else(|_| "unknown".to_string());
102 let pid = std::process::id();
103 let os_info = format!(
104 "{} {} ({})",
105 std::env::consts::OS,
106 std::env::consts::ARCH,
107 std::env::consts::FAMILY
108 );
109
110 Self {
111 runtime: "rust".to_string(),
112 version: SDK_VERSION.to_string(),
113 name: format!("{}:{}", hostname, pid),
114 os: os_info,
115 }
116 }
117}
118
/// Items placed on the outbound channel that the connection task drains.
enum Outbound {
    /// A protocol message to be written to the engine websocket.
    Message(Message),
    /// Tells the connection task to clear `running` and exit.
    Shutdown,
}
123
/// Completion channel for an in-flight [`III::call_with_timeout`], stored in
/// `IIIInner::pending` under its invocation id and fulfilled when the matching
/// `InvocationResult` arrives.
type PendingInvocation = oneshot::Sender<Result<Value, IIIError>>;
125
/// Write half of the engine websocket, as produced by `stream.split()` in
/// `run_connection`.
type WsTx = futures_util::stream::SplitSink<
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
    WsMessage,
>;
131
/// Returns the `(traceparent, baggage)` headers for the currently active
/// OpenTelemetry context, for attaching to outgoing messages.
#[cfg(feature = "otel")]
fn inject_trace_headers() -> (Option<String>, Option<String>) {
    let traceparent = crate::telemetry::context::inject_traceparent();
    let baggage = crate::telemetry::context::inject_baggage();
    (traceparent, baggage)
}

/// Without the `otel` feature there is no trace context to propagate, so both
/// headers are always `None`.
#[cfg(not(feature = "otel"))]
fn inject_trace_headers() -> (Option<String>, Option<String>) {
    Default::default()
}
144
/// Callback registered with [`III::on_functions_available`]; receives the function
/// list carried by each `engine::functions-available` notification.
pub type FunctionsAvailableCallback = Arc<dyn Fn(Vec<FunctionInfo>) + Send + Sync>;
147
/// State shared by every clone of [`III`].
struct IIIInner {
    /// Engine websocket URL passed to `connect_async`.
    address: String,
    /// Sender half of the outbound queue (fed by `send_message` and the shutdown methods).
    outbound: mpsc::UnboundedSender<Outbound>,
    /// Receiver half of the outbound queue; taken by the first `connect()` and moved
    /// into the connection task.
    receiver: Mutex<Option<mpsc::UnboundedReceiver<Outbound>>>,
    /// True while the client should stay connected. `send_message` silently drops
    /// messages while this is false, and the connection task exits once it is cleared.
    running: AtomicBool,
    /// Set by the first `connect()`; later calls return early.
    started: AtomicBool,
    /// Result channels for in-flight `call_with_timeout` invocations, by invocation id.
    pending: Mutex<HashMap<Uuid, PendingInvocation>>,
    /// Locally registered functions by id; replayed to the engine on every (re)connect.
    functions: Mutex<HashMap<String, RemoteFunctionData>>,
    /// Locally registered trigger types by id; replayed on every (re)connect.
    trigger_types: Mutex<HashMap<String, RemoteTriggerTypeData>>,
    /// Triggers registered by this client, by trigger id; replayed on every (re)connect.
    triggers: Mutex<HashMap<String, RegisterTriggerMessage>>,
    /// Registered services by id; replayed on every (re)connect.
    services: Mutex<HashMap<String, RegisterServiceMessage>>,
    /// Metadata sent via `engine.workers.register` after each connect, if set.
    worker_metadata: Mutex<Option<WorkerMetadata>>,
    /// Active `on_functions_available` callbacks keyed by guard id.
    functions_available_callbacks: Mutex<HashMap<usize, FunctionsAvailableCallback>>,
    /// Source of guard ids for `functions_available_callbacks`.
    functions_available_callback_counter: AtomicUsize,
    /// Lazily generated id of the internal function that receives
    /// functions-available notifications; reused once created.
    functions_available_function_id: Mutex<Option<String>>,
    /// Trigger binding that internal function to `engine::functions-available`;
    /// set by `on_functions_available`, cleared when the last guard is dropped.
    functions_available_trigger: Mutex<Option<Trigger>>,
    /// Telemetry config consumed (taken) by the first `connect()`.
    #[cfg(feature = "otel")]
    otel_config: Mutex<Option<OtelConfig>>,
}
167
/// Client handle for an iii engine.
///
/// Cloning is cheap: every clone shares the same `IIIInner` state through an `Arc`.
#[derive(Clone)]
pub struct III {
    inner: Arc<IIIInner>,
}
172
173pub struct FunctionsAvailableGuard {
175 iii: III,
176 callback_id: usize,
177}
178
179impl Drop for FunctionsAvailableGuard {
180 fn drop(&mut self) {
181 let mut callbacks = self
182 .iii
183 .inner
184 .functions_available_callbacks
185 .lock_or_recover();
186 callbacks.remove(&self.callback_id);
187
188 if callbacks.is_empty() {
189 let mut trigger = self.iii.inner.functions_available_trigger.lock_or_recover();
190 if let Some(trigger) = trigger.take() {
191 trigger.unregister();
192 }
193 }
194 }
195}
196
197impl III {
198 pub fn new(address: &str) -> Self {
200 Self::with_metadata(address, WorkerMetadata::default())
201 }
202
203 pub fn with_metadata(address: &str, metadata: WorkerMetadata) -> Self {
205 let (tx, rx) = mpsc::unbounded_channel();
206 let inner = IIIInner {
207 address: address.into(),
208 outbound: tx,
209 receiver: Mutex::new(Some(rx)),
210 running: AtomicBool::new(false),
211 started: AtomicBool::new(false),
212 pending: Mutex::new(HashMap::new()),
213 functions: Mutex::new(HashMap::new()),
214 trigger_types: Mutex::new(HashMap::new()),
215 triggers: Mutex::new(HashMap::new()),
216 services: Mutex::new(HashMap::new()),
217 worker_metadata: Mutex::new(Some(metadata)),
218 functions_available_callbacks: Mutex::new(HashMap::new()),
219 functions_available_callback_counter: AtomicUsize::new(0),
220 functions_available_function_id: Mutex::new(None),
221 functions_available_trigger: Mutex::new(None),
222 #[cfg(feature = "otel")]
223 otel_config: Mutex::new(None),
224 };
225 Self {
226 inner: Arc::new(inner),
227 }
228 }
229
230 pub fn set_metadata(&self, metadata: WorkerMetadata) {
232 *self.inner.worker_metadata.lock_or_recover() = Some(metadata);
233 }
234
235 #[cfg(feature = "otel")]
237 pub fn set_otel_config(&self, config: OtelConfig) {
238 *self.inner.otel_config.lock_or_recover() = Some(config);
239 }
240
241 pub async fn connect(&self) -> Result<(), IIIError> {
242 if self.inner.started.swap(true, Ordering::SeqCst) {
243 return Ok(());
244 }
245
246 let receiver = self.inner.receiver.lock_or_recover().take();
247 let Some(rx) = receiver else {
248 return Ok(());
249 };
250
251 let iii = self.clone();
252
253 tokio::spawn(async move {
254 iii.inner.running.store(true, Ordering::SeqCst);
255 iii.run_connection(rx).await;
256 });
257
258 #[cfg(feature = "otel")]
264 {
265 let config = self.inner.otel_config.lock_or_recover().take();
266 if let Some(mut config) = config {
267 if config.engine_ws_url.is_none() {
269 config.engine_ws_url = Some(self.inner.address.clone());
270 }
271 telemetry::init_otel(config).await;
272 }
273 }
274
275 Ok(())
276 }
277
278 #[deprecated(note = "Use shutdown_async() for guaranteed telemetry flush")]
285 pub fn shutdown(&self) {
286 self.inner.running.store(false, Ordering::SeqCst);
287 let _ = self.inner.outbound.send(Outbound::Shutdown);
288
289 #[cfg(feature = "otel")]
291 {
292 tracing::warn!(
293 "shutdown() does not await telemetry flush; use shutdown_async() instead"
294 );
295 tokio::spawn(async {
296 telemetry::shutdown_otel().await;
297 });
298 }
299 }
300
301 pub async fn shutdown_async(&self) {
308 self.inner.running.store(false, Ordering::SeqCst);
309 let _ = self.inner.outbound.send(Outbound::Shutdown);
310
311 #[cfg(feature = "otel")]
312 telemetry::shutdown_otel().await;
313 }
314
315 pub fn register_function<F, Fut>(&self, id: impl Into<String>, handler: F)
316 where
317 F: Fn(Value) -> Fut + Send + Sync + 'static,
318 Fut: std::future::Future<Output = Result<Value, IIIError>> + Send + 'static,
319 {
320 let message = RegisterFunctionMessage {
321 id: id.into(),
322 description: None,
323 request_format: None,
324 response_format: None,
325 metadata: None,
326 };
327
328 self.register_function_with(message, handler);
329 }
330
331 pub fn register_function_with_description<F, Fut>(
332 &self,
333 id: impl Into<String>,
334 description: impl Into<String>,
335 handler: F,
336 ) where
337 F: Fn(Value) -> Fut + Send + Sync + 'static,
338 Fut: std::future::Future<Output = Result<Value, IIIError>> + Send + 'static,
339 {
340 let message = RegisterFunctionMessage {
341 id: id.into(),
342 description: Some(description.into()),
343 request_format: None,
344 response_format: None,
345 metadata: None,
346 };
347
348 self.register_function_with(message, handler);
349 }
350
351 pub fn register_function_with<F, Fut>(&self, message: RegisterFunctionMessage, handler: F)
352 where
353 F: Fn(Value) -> Fut + Send + Sync + 'static,
354 Fut: std::future::Future<Output = Result<Value, IIIError>> + Send + 'static,
355 {
356 let function_id = message.id.clone();
357 let iii = self.clone();
358
359 let user_handler = Arc::new(move |input: Value| Box::pin(handler(input)));
360
361 let wrapped_handler: RemoteFunctionHandler = Arc::new(move |input: Value| {
362 let function_id = function_id.clone();
363 let iii = iii.clone();
364 let user_handler = user_handler.clone();
365
366 Box::pin(async move {
367 let invoker: LoggerInvoker = Arc::new(move |path, params| {
368 let _ = iii.call_void(path, params);
369 });
370
371 let logger = Logger::new(
372 Some(invoker),
373 Some(Uuid::new_v4().to_string()),
374 Some(function_id.clone()),
375 );
376 let context = Context { logger };
377
378 with_context(context, || user_handler(input)).await
379 })
380 });
381
382 let data = RemoteFunctionData {
383 message: message.clone(),
384 handler: wrapped_handler,
385 };
386
387 self.inner
388 .functions
389 .lock_or_recover()
390 .insert(message.id.clone(), data);
391 let _ = self.send_message(message.to_message());
392 }
393
394 pub fn register_service(&self, id: impl Into<String>, description: Option<String>) {
395 let id = id.into();
396 let message = RegisterServiceMessage {
397 id: id.clone(),
398 name: id,
399 description,
400 };
401
402 self.inner
403 .services
404 .lock_or_recover()
405 .insert(message.id.clone(), message.clone());
406 let _ = self.send_message(message.to_message());
407 }
408
409 pub fn register_service_with_name(
410 &self,
411 id: impl Into<String>,
412 name: impl Into<String>,
413 description: Option<String>,
414 ) {
415 let message = RegisterServiceMessage {
416 id: id.into(),
417 name: name.into(),
418 description,
419 };
420
421 self.inner
422 .services
423 .lock_or_recover()
424 .insert(message.id.clone(), message.clone());
425 let _ = self.send_message(message.to_message());
426 }
427
428 pub fn register_trigger_type<H>(
429 &self,
430 id: impl Into<String>,
431 description: impl Into<String>,
432 handler: H,
433 ) where
434 H: TriggerHandler + 'static,
435 {
436 let message = RegisterTriggerTypeMessage {
437 id: id.into(),
438 description: description.into(),
439 };
440
441 self.inner.trigger_types.lock_or_recover().insert(
442 message.id.clone(),
443 RemoteTriggerTypeData {
444 message: message.clone(),
445 handler: Arc::new(handler),
446 },
447 );
448
449 let _ = self.send_message(message.to_message());
450 }
451
452 pub fn unregister_trigger_type(&self, id: impl Into<String>) {
453 let id = id.into();
454 self.inner.trigger_types.lock_or_recover().remove(&id);
455 }
456
457 pub fn register_trigger(
458 &self,
459 trigger_type: impl Into<String>,
460 function_id: impl Into<String>,
461 config: impl serde::Serialize,
462 ) -> Result<Trigger, IIIError> {
463 let id = Uuid::new_v4().to_string();
464 let config = serde_json::to_value(config)?;
465 let message = RegisterTriggerMessage {
466 id: id.clone(),
467 trigger_type: trigger_type.into(),
468 function_id: function_id.into(),
469 config,
470 };
471
472 self.inner
473 .triggers
474 .lock_or_recover()
475 .insert(message.id.clone(), message.clone());
476 let _ = self.send_message(message.to_message());
477
478 let iii = self.clone();
479 let trigger_type = message.trigger_type.clone();
480 let unregister_id = message.id.clone();
481 let unregister_fn = Arc::new(move || {
482 let _ = iii.inner.triggers.lock_or_recover().remove(&unregister_id);
483 let msg = UnregisterTriggerMessage {
484 id: unregister_id.clone(),
485 trigger_type: trigger_type.clone(),
486 };
487 let _ = iii.send_message(msg.to_message());
488 });
489
490 Ok(Trigger::new(unregister_fn))
491 }
492
493 pub async fn call(
494 &self,
495 function_id: &str,
496 data: impl serde::Serialize,
497 ) -> Result<Value, IIIError> {
498 let value = serde_json::to_value(data)?;
499 self.call_with_timeout(function_id, value, DEFAULT_TIMEOUT)
500 .await
501 }
502
503 pub async fn call_with_timeout(
504 &self,
505 function_id: &str,
506 data: Value,
507 timeout: Duration,
508 ) -> Result<Value, IIIError> {
509 let invocation_id = Uuid::new_v4();
510 let (tx, rx) = oneshot::channel();
511
512 self.inner
513 .pending
514 .lock_or_recover()
515 .insert(invocation_id, tx);
516
517 let (tp, bg) = inject_trace_headers();
518
519 self.send_message(Message::InvokeFunction {
520 invocation_id: Some(invocation_id),
521 function_id: function_id.to_string(),
522 data,
523 traceparent: tp,
524 baggage: bg,
525 })?;
526
527 match tokio::time::timeout(timeout, rx).await {
528 Ok(Ok(result)) => result,
529 Ok(Err(_)) => Err(IIIError::NotConnected),
530 Err(_) => {
531 self.inner.pending.lock_or_recover().remove(&invocation_id);
532 Err(IIIError::Timeout)
533 }
534 }
535 }
536
537 pub fn call_void<TInput>(&self, function_id: &str, data: TInput) -> Result<(), IIIError>
538 where
539 TInput: Serialize,
540 {
541 let value = serde_json::to_value(data)?;
542
543 let (tp, bg) = inject_trace_headers();
544
545 self.send_message(Message::InvokeFunction {
546 invocation_id: None,
547 function_id: function_id.to_string(),
548 data: value,
549 traceparent: tp,
550 baggage: bg,
551 })
552 }
553
554 pub async fn list_functions(&self) -> Result<Vec<FunctionInfo>, IIIError> {
556 let result = self
557 .call("engine.functions.list", serde_json::json!({}))
558 .await?;
559
560 let functions = result
561 .get("functions")
562 .and_then(|v| serde_json::from_value::<Vec<FunctionInfo>>(v.clone()).ok())
563 .unwrap_or_default();
564
565 Ok(functions)
566 }
567
568 pub fn on_functions_available<F>(&self, callback: F) -> FunctionsAvailableGuard
571 where
572 F: Fn(Vec<FunctionInfo>) + Send + Sync + 'static,
573 {
574 let callback = Arc::new(callback);
575 let callback_id = self
576 .inner
577 .functions_available_callback_counter
578 .fetch_add(1, Ordering::Relaxed);
579
580 self.inner
581 .functions_available_callbacks
582 .lock_or_recover()
583 .insert(callback_id, callback);
584
585 let mut trigger_guard = self.inner.functions_available_trigger.lock_or_recover();
587 if trigger_guard.is_none() {
588 let function_id = {
590 let mut path_guard = self.inner.functions_available_function_id.lock_or_recover();
591 if path_guard.is_none() {
592 let path = format!("iii.on_functions_available.{}", Uuid::new_v4());
593 *path_guard = Some(path.clone());
594 path
595 } else {
596 path_guard.clone().unwrap()
597 }
598 };
599
600 let function_exists = self
602 .inner
603 .functions
604 .lock_or_recover()
605 .contains_key(&function_id);
606 if !function_exists {
607 let iii = self.clone();
608 self.register_function(function_id.clone(), move |input: Value| {
609 let iii = iii.clone();
610 async move {
611 let functions = input
613 .get("functions")
614 .and_then(|v| {
615 serde_json::from_value::<Vec<FunctionInfo>>(v.clone()).ok()
616 })
617 .unwrap_or_default();
618
619 let callbacks = iii.inner.functions_available_callbacks.lock_or_recover();
620 for cb in callbacks.values() {
621 cb(functions.clone());
622 }
623 Ok(Value::Null)
624 }
625 });
626 }
627
628 match self.register_trigger(
630 "engine::functions-available",
631 function_id,
632 serde_json::json!({}),
633 ) {
634 Ok(trigger) => {
635 *trigger_guard = Some(trigger);
636 }
637 Err(err) => {
638 tracing::warn!(error = %err, "Failed to register functions_available trigger");
639 }
640 }
641 }
642
643 FunctionsAvailableGuard {
644 iii: self.clone(),
645 callback_id,
646 }
647 }
648
649 pub async fn list_workers(&self) -> Result<Vec<WorkerInfo>, IIIError> {
651 let result = self
652 .call("engine.workers.list", serde_json::json!({}))
653 .await?;
654
655 let workers = result
656 .get("workers")
657 .and_then(|v| serde_json::from_value::<Vec<WorkerInfo>>(v.clone()).ok())
658 .unwrap_or_default();
659
660 Ok(workers)
661 }
662
663 pub async fn list_triggers(&self) -> Result<Vec<TriggerInfo>, IIIError> {
665 let result = self
666 .call("engine.triggers.list", serde_json::json!({}))
667 .await?;
668
669 let triggers = result
670 .get("triggers")
671 .and_then(|v| serde_json::from_value::<Vec<TriggerInfo>>(v.clone()).ok())
672 .unwrap_or_default();
673
674 Ok(triggers)
675 }
676
677 fn register_worker_metadata(&self) {
679 if let Some(metadata) = self.inner.worker_metadata.lock_or_recover().clone() {
680 let _ = self.call_void("engine.workers.register", metadata);
681 }
682 }
683
684 fn send_message(&self, message: Message) -> Result<(), IIIError> {
685 if !self.inner.running.load(Ordering::SeqCst) {
686 return Ok(());
687 }
688
689 self.inner
690 .outbound
691 .send(Outbound::Message(message))
692 .map_err(|_| IIIError::NotConnected)
693 }
694
    /// Body of the background connection task spawned by [`III::connect`].
    ///
    /// While `running` is set, each iteration does the following:
    /// 1. Connect to `address`. On success, append every locally stored registration
    ///    to `queue`, dedupe, and flush the queue to the socket. If the flush fails,
    ///    wait 2 seconds and start over; unsent messages stay in `queue`.
    /// 2. Announce worker metadata.
    /// 3. Pump messages. Items from the outbound channel are written to the socket and
    ///    incoming frames are dispatched. This continues until a send error, receive
    ///    error, or end of stream requests a reconnect.
    ///
    /// Between attempts the task sleeps 2 seconds. On `Outbound::Shutdown`, or when
    /// the outbound channel is closed, it clears `running` and returns.
    async fn run_connection(&self, mut rx: mpsc::UnboundedReceiver<Outbound>) {
        // Messages not yet written; persists across reconnects so they are retried.
        let mut queue: Vec<Message> = Vec::new();

        while self.inner.running.load(Ordering::SeqCst) {
            match connect_async(&self.inner.address).await {
                Ok((stream, _)) => {
                    tracing::info!(address = %self.inner.address, "iii connected");
                    let (mut ws_tx, mut ws_rx) = stream.split();

                    // Replay local registrations. Dedupe keeps only the last message per
                    // registration id, so stale queued copies are not sent twice.
                    queue.extend(self.collect_registrations());
                    Self::dedupe_registrations(&mut queue);
                    if let Err(err) = self.flush_queue(&mut ws_tx, &mut queue).await {
                        tracing::warn!(error = %err, "failed to flush queue");
                        sleep(Duration::from_secs(2)).await;
                        // `continue` skips the sleep at the bottom of the loop, so the
                        // retry delay is still 2 seconds.
                        continue;
                    }

                    // Goes through the outbound channel; written by the loop below.
                    self.register_worker_metadata();

                    let mut should_reconnect = false;

                    while self.inner.running.load(Ordering::SeqCst) && !should_reconnect {
                        tokio::select! {
                            outgoing = rx.recv() => {
                                match outgoing {
                                    Some(Outbound::Message(message)) => {
                                        if let Err(err) = self.send_ws(&mut ws_tx, &message).await {
                                            tracing::warn!(error = %err, "send failed; reconnecting");
                                            // Keep the message so it is retried after reconnecting.
                                            queue.push(message);
                                            should_reconnect = true;
                                        }
                                    }
                                    Some(Outbound::Shutdown) => {
                                        self.inner.running.store(false, Ordering::SeqCst);
                                        return;
                                    }
                                    // Channel closed: nothing more can ever be sent.
                                    None => {
                                        self.inner.running.store(false, Ordering::SeqCst);
                                        return;
                                    }
                                }
                            }
                            incoming = ws_rx.next() => {
                                match incoming {
                                    Some(Ok(frame)) => {
                                        // Per-frame failures (e.g. unparsable JSON) are logged, not fatal.
                                        if let Err(err) = self.handle_frame(frame) {
                                            tracing::warn!(error = %err, "failed to handle frame");
                                        }
                                    }
                                    Some(Err(err)) => {
                                        tracing::warn!(error = %err, "websocket receive error");
                                        should_reconnect = true;
                                    }
                                    // Stream ended: the connection was closed.
                                    None => {
                                        should_reconnect = true;
                                    }
                                }
                            }
                        }
                    }
                }
                Err(err) => {
                    tracing::warn!(error = %err, "failed to connect; retrying");
                }
            }

            // Back off before the next connection attempt (skipped when shutting down).
            if self.inner.running.load(Ordering::SeqCst) {
                sleep(Duration::from_secs(2)).await;
            }
        }
    }
767
768 fn collect_registrations(&self) -> Vec<Message> {
769 let mut messages = Vec::new();
770
771 for trigger_type in self.inner.trigger_types.lock_or_recover().values() {
772 messages.push(trigger_type.message.to_message());
773 }
774
775 for service in self.inner.services.lock_or_recover().values() {
776 messages.push(service.to_message());
777 }
778
779 for function in self.inner.functions.lock_or_recover().values() {
780 messages.push(function.message.to_message());
781 }
782
783 for trigger in self.inner.triggers.lock_or_recover().values() {
784 messages.push(trigger.to_message());
785 }
786
787 messages
788 }
789
790 fn dedupe_registrations(queue: &mut Vec<Message>) {
791 let mut seen = HashSet::new();
792 let mut deduped_rev = Vec::with_capacity(queue.len());
793
794 for message in queue.iter().rev() {
795 let key = match message {
796 Message::RegisterTriggerType { id, .. } => format!("trigger_type:{id}"),
797 Message::RegisterTrigger { id, .. } => format!("trigger:{id}"),
798 Message::RegisterFunction { id, .. } => {
799 format!("function:{id}")
800 }
801 Message::RegisterService { id, .. } => format!("service:{id}"),
802 _ => {
803 deduped_rev.push(message.clone());
804 continue;
805 }
806 };
807
808 if seen.insert(key) {
809 deduped_rev.push(message.clone());
810 }
811 }
812
813 deduped_rev.reverse();
814 *queue = deduped_rev;
815 }
816
817 async fn flush_queue(
818 &self,
819 ws_tx: &mut WsTx,
820 queue: &mut Vec<Message>,
821 ) -> Result<(), IIIError> {
822 let mut drained = Vec::new();
823 std::mem::swap(queue, &mut drained);
824
825 let mut iter = drained.into_iter();
826 while let Some(message) = iter.next() {
827 if let Err(err) = self.send_ws(ws_tx, &message).await {
828 queue.push(message);
829 queue.extend(iter);
830 return Err(err);
831 }
832 }
833
834 Ok(())
835 }
836
837 async fn send_ws(&self, ws_tx: &mut WsTx, message: &Message) -> Result<(), IIIError> {
838 let payload = serde_json::to_string(message)?;
839 ws_tx.send(WsMessage::Text(payload.into())).await?;
840 Ok(())
841 }
842
843 fn handle_frame(&self, frame: WsMessage) -> Result<(), IIIError> {
844 match frame {
845 WsMessage::Text(text) => self.handle_message(&text),
846 WsMessage::Binary(bytes) => {
847 let text = String::from_utf8_lossy(&bytes).to_string();
848 self.handle_message(&text)
849 }
850 _ => Ok(()),
851 }
852 }
853
854 fn handle_message(&self, payload: &str) -> Result<(), IIIError> {
855 let message: Message = serde_json::from_str(payload)?;
856
857 match message {
858 Message::InvocationResult {
859 invocation_id,
860 result,
861 error,
862 ..
863 } => {
864 self.handle_invocation_result(invocation_id, result, error);
865 }
866 Message::InvokeFunction {
867 invocation_id,
868 function_id,
869 data,
870 traceparent,
871 baggage,
872 } => {
873 self.handle_invoke_function(invocation_id, function_id, data, traceparent, baggage);
874 }
875 Message::RegisterTrigger {
876 id,
877 trigger_type,
878 function_id,
879 config,
880 } => {
881 self.handle_register_trigger(id, trigger_type, function_id, config);
882 }
883 Message::Ping => {
884 let _ = self.send_message(Message::Pong);
885 }
886 Message::WorkerRegistered { worker_id } => {
887 tracing::debug!(worker_id = %worker_id, "Worker registered");
888 }
889 _ => {}
890 }
891
892 Ok(())
893 }
894
895 fn handle_invocation_result(
896 &self,
897 invocation_id: Uuid,
898 result: Option<Value>,
899 error: Option<ErrorBody>,
900 ) {
901 let sender = self.inner.pending.lock_or_recover().remove(&invocation_id);
902 if let Some(sender) = sender {
903 let result = match error {
904 Some(error) => Err(IIIError::Remote {
905 code: error.code,
906 message: error.message,
907 }),
908 None => Ok(result.unwrap_or(Value::Null)),
909 };
910 let _ = sender.send(result);
911 }
912 }
913
    /// Runs a locally registered function in response to an engine `InvokeFunction`.
    ///
    /// If `function_id` is not registered, a warning is logged. When an
    /// `invocation_id` is present, a `function_not_found` error reply is also sent.
    ///
    /// Otherwise the handler runs in a spawned task. When `invocation_id` is present,
    /// its outcome is sent back as an `InvocationResult`; handler errors are reported
    /// with code `invocation_failed`. Invocations without an id are fire-and-forget,
    /// and handler errors for them are only logged.
    ///
    /// With the `otel` feature, the handler runs inside a server span whose parent is
    /// extracted from `traceparent`/`baggage`. The span status reflects the outcome,
    /// and the reply's trace headers are injected from that span's context.
    fn handle_invoke_function(
        &self,
        invocation_id: Option<Uuid>,
        function_id: String,
        data: Value,
        traceparent: Option<String>,
        baggage: Option<String>,
    ) {
        tracing::debug!(function_id = %function_id, traceparent = ?traceparent, baggage = ?baggage, "Invoking function");

        // Clone the handler out so the `functions` lock is released before it runs.
        let handler = self
            .inner
            .functions
            .lock_or_recover()
            .get(&function_id)
            .map(|data| data.handler.clone());

        let Some(handler) = handler else {
            tracing::warn!(function_id = %function_id, "Invocation: Function not found");

            // Only invocations that expect a result (have an id) get an error reply.
            if let Some(invocation_id) = invocation_id {
                let (resp_tp, resp_bg) = inject_trace_headers();

                let error = ErrorBody {
                    code: "function_not_found".to_string(),
                    message: "Function not found".to_string(),
                };
                let result = self.send_message(Message::InvocationResult {
                    invocation_id,
                    function_id,
                    result: None,
                    error: Some(error),
                    traceparent: resp_tp,
                    baggage: resp_bg,
                });

                if let Err(err) = result {
                    tracing::warn!(error = %err, "error sending invocation result");
                }
            }
            return;
        };

        let iii = self.clone();

        // Run the handler in its own task; this function is called synchronously from
        // the connection task's receive path, which keeps processing frames meanwhile.
        tokio::spawn(async move {
            // Server span parented to the caller's propagated trace context.
            #[cfg(feature = "otel")]
            let otel_cx = {
                use crate::telemetry::context::extract_context;
                use opentelemetry::trace::{SpanKind, TraceContextExt, Tracer};

                let parent_cx = extract_context(traceparent.as_deref(), baggage.as_deref());
                let tracer = opentelemetry::global::tracer("iii-rust-sdk");
                let span = tracer
                    .span_builder(format!("invoke {}", function_id))
                    .with_kind(SpanKind::Server)
                    .start_with_context(&tracer, &parent_cx);
                parent_cx.with_span(span)
            };

            // Poll the handler with the span's context active.
            #[cfg(feature = "otel")]
            let result = {
                use opentelemetry::trace::FutureExt as OtelFutureExt;
                handler(data).with_context(otel_cx.clone()).await
            };

            #[cfg(not(feature = "otel"))]
            let result = handler(data).await;

            // Record the outcome on the span.
            #[cfg(feature = "otel")]
            {
                use opentelemetry::trace::{Status, TraceContextExt};
                let span = otel_cx.span();
                match &result {
                    Ok(_) => span.set_status(Status::Ok),
                    Err(err) => span.set_status(Status::error(err.to_string())),
                }
            }

            if let Some(invocation_id) = invocation_id {
                // Attach the span's context while injecting, so the reply's trace
                // headers refer to this invocation's span.
                #[cfg(feature = "otel")]
                let (resp_tp, resp_bg) = {
                    let _guard = otel_cx.attach();
                    inject_trace_headers()
                };
                #[cfg(not(feature = "otel"))]
                let (resp_tp, resp_bg) = inject_trace_headers();

                let message = match result {
                    Ok(value) => Message::InvocationResult {
                        invocation_id,
                        function_id,
                        result: Some(value),
                        error: None,
                        traceparent: resp_tp,
                        baggage: resp_bg,
                    },
                    Err(err) => Message::InvocationResult {
                        invocation_id,
                        function_id,
                        result: None,
                        error: Some(ErrorBody {
                            code: "invocation_failed".to_string(),
                            message: err.to_string(),
                        }),
                        traceparent: resp_tp,
                        baggage: resp_bg,
                    },
                };

                let _ = iii.send_message(message);
            } else if let Err(err) = result {
                // Fire-and-forget invocation: nobody awaits the result, so only log.
                tracing::warn!(error = %err, "error handling async invocation");
            }
        });
    }
1039
1040 fn handle_register_trigger(
1041 &self,
1042 id: String,
1043 trigger_type: String,
1044 function_id: String,
1045 config: Value,
1046 ) {
1047 let handler = self
1048 .inner
1049 .trigger_types
1050 .lock_or_recover()
1051 .get(&trigger_type)
1052 .map(|data| data.handler.clone());
1053
1054 let iii = self.clone();
1055
1056 tokio::spawn(async move {
1057 let message = if let Some(handler) = handler {
1058 let config = TriggerConfig {
1059 id: id.clone(),
1060 function_id: function_id.clone(),
1061 config,
1062 };
1063
1064 match handler.register_trigger(config).await {
1065 Ok(()) => Message::TriggerRegistrationResult {
1066 id,
1067 trigger_type,
1068 function_id,
1069 error: None,
1070 },
1071 Err(err) => Message::TriggerRegistrationResult {
1072 id,
1073 trigger_type,
1074 function_id,
1075 error: Some(ErrorBody {
1076 code: "trigger_registration_failed".to_string(),
1077 message: err.to_string(),
1078 }),
1079 },
1080 }
1081 } else {
1082 Message::TriggerRegistrationResult {
1083 id,
1084 trigger_type,
1085 function_id,
1086 error: Some(ErrorBody {
1087 code: "trigger_type_not_found".to_string(),
1088 message: "Trigger type not found".to_string(),
1089 }),
1090 }
1091 };
1092
1093 let _ = iii.send_message(message);
1094 });
1095 }
1096}
1097
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Unregistering a trigger handle removes it from the local registry.
    #[test]
    fn register_trigger_unregister_removes_entry() {
        let iii = III::new("ws://localhost:1234");
        let trigger = iii
            .register_trigger("demo", "functions.echo", json!({ "foo": "bar" }))
            .unwrap();

        assert_eq!(iii.inner.triggers.lock().unwrap().len(), 1);

        trigger.unregister();

        assert_eq!(iii.inner.triggers.lock().unwrap().len(), 0);
    }

    /// A call that never gets a reply times out and does not leak its pending slot.
    #[tokio::test]
    async fn invoke_function_times_out_and_clears_pending() {
        let iii = III::new("ws://localhost:1234");
        let result = iii
            .call_with_timeout(
                "functions.echo",
                json!({ "a": 1 }),
                Duration::from_millis(10),
            )
            .await;

        assert!(matches!(result, Err(IIIError::Timeout)));
        assert!(iii.inner.pending.lock().unwrap().is_empty());
    }

    /// Dedupe keeps only the latest registration per id and leaves other messages alone.
    #[test]
    fn dedupe_registrations_keeps_latest_registration() {
        let register = |config: Value| Message::RegisterTrigger {
            id: "trigger-1".to_string(),
            trigger_type: "demo".to_string(),
            function_id: "functions.echo".to_string(),
            config,
        };
        let mut queue = vec![
            register(json!({ "v": 1 })),
            Message::Ping,
            register(json!({ "v": 2 })),
        ];

        III::dedupe_registrations(&mut queue);

        assert_eq!(queue.len(), 2);
        assert!(matches!(queue[0], Message::Ping));
        assert!(matches!(
            &queue[1],
            Message::RegisterTrigger { config, .. } if *config == json!({ "v": 2 })
        ));
    }

    /// `register_service` stores the service under its id, using the id as its name.
    #[test]
    fn register_service_uses_id_as_name() {
        let iii = III::new("ws://localhost:1234");
        iii.register_service("svc", Some("demo service".to_string()));

        let services = iii.inner.services.lock().unwrap();
        let service = services.get("svc").expect("service should be stored");
        assert_eq!(service.name, "svc");
    }

    /// Dropping the last functions-available guard unregisters the shared trigger.
    #[test]
    fn functions_available_guard_drop_unregisters_trigger() {
        let iii = III::new("ws://localhost:1234");
        let guard = iii.on_functions_available(|_| {});

        assert_eq!(iii.inner.triggers.lock().unwrap().len(), 1);
        assert!(iii.inner.functions_available_trigger.lock().unwrap().is_some());

        drop(guard);

        assert_eq!(iii.inner.triggers.lock().unwrap().len(), 0);
        assert!(iii.inner.functions_available_trigger.lock().unwrap().is_none());
    }
}
1132}