1use std::{
2 collections::{HashMap, HashSet},
3 sync::{
4 Arc, Mutex, MutexGuard,
5 atomic::{AtomicBool, AtomicUsize, Ordering},
6 },
7 time::Duration,
8};
9
/// Extension trait for [`Mutex`] whose lock operation cannot fail.
trait MutexExt<T> {
    /// Locks the mutex and returns the guard, even if the mutex is poisoned.
    fn lock_or_recover(&self) -> MutexGuard<'_, T>;
}

impl<T> MutexExt<T> for Mutex<T> {
    fn lock_or_recover(&self) -> MutexGuard<'_, T> {
        // A poisoned mutex still owns its data: pull the guard out of the
        // `PoisonError` instead of propagating another thread's panic.
        match self.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
21
22use futures_util::{SinkExt, StreamExt};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use tokio::{
26 sync::{mpsc, oneshot},
27 time::sleep,
28};
29use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
30use uuid::Uuid;
31
/// Version of this crate, captured from `CARGO_PKG_VERSION` at compile time.
const SDK_VERSION: &str = env!("CARGO_PKG_VERSION");
33
34use crate::{
35 channels::{ChannelReader, ChannelWriter, StreamChannelRef},
36 error::IIIError,
37 protocol::{
38 ErrorBody, HttpInvocationConfig, Message, RegisterFunctionMessage, RegisterServiceMessage,
39 RegisterTriggerInput, RegisterTriggerMessage, RegisterTriggerTypeMessage, TriggerAction,
40 TriggerRequest, UnregisterTriggerMessage, UnregisterTriggerTypeMessage,
41 },
42 triggers::{Trigger, TriggerConfig, TriggerHandler},
43 types::{Channel, RemoteFunctionData, RemoteFunctionHandler, RemoteTriggerTypeData},
44};
45
46use crate::telemetry;
47use crate::telemetry::types::OtelConfig;
48
/// Timeout, in milliseconds, that `III::trigger` applies when the request
/// does not carry its own `timeout_ms`.
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
50
/// One element of the `workers` array that `III::list_workers` deserializes
/// from the `engine::workers::list` response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    pub id: String,
    pub name: Option<String>,
    pub runtime: Option<String>,
    pub version: Option<String>,
    pub os: Option<String>,
    pub ip_address: Option<String>,
    pub status: String,
    // NOTE(review): presumably a connection timestamp in milliseconds — confirm the epoch with the engine.
    pub connected_at_ms: u64,
    pub function_count: usize,
    pub functions: Vec<String>,
    pub active_invocations: usize,
}
66
/// Description of a function as reported by the engine.
///
/// Deserialized from the `functions` array of the `engine::functions::list`
/// response and from the payload handed to functions-available callbacks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionInfo {
    pub function_id: String,
    pub description: Option<String>,
    pub request_format: Option<Value>,
    pub response_format: Option<Value>,
    pub metadata: Option<Value>,
}
76
/// One element of the `triggers` array that `III::list_triggers` deserializes
/// from the `engine::triggers::list` response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TriggerInfo {
    pub id: String,
    pub trigger_type: String,
    pub function_id: String,
    pub config: Value,
    pub metadata: Option<Value>,
}
86
/// One element of the `trigger_types` array that `III::list_trigger_types`
/// deserializes from the `engine::trigger-types::list` response.
///
/// The two format fields are left out of the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TriggerTypeInfo {
    pub id: String,
    pub description: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trigger_request_format: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub call_request_format: Option<Value>,
}
97
/// Builder describing a trigger type to hand to `III::register_trigger_type`.
///
/// `C` and `R` are carried only as type markers; they end up as the type
/// parameters of the `TriggerTypeRef` returned by the registration.
pub struct RegisterTriggerType<H, C = Value, R = Value> {
    id: String,
    description: String,
    handler: H,
    // JSON schema values produced by `json_schema_for`, or `None` when unset.
    trigger_request_format: Option<Value>,
    call_request_format: Option<Value>,
    _phantom: std::marker::PhantomData<(C, R)>,
}
116
117impl<H: TriggerHandler> RegisterTriggerType<H> {
118 pub fn new(id: impl Into<String>, description: impl Into<String>, handler: H) -> Self {
119 Self {
120 id: id.into(),
121 description: description.into(),
122 handler,
123 trigger_request_format: None,
124 call_request_format: None,
125 _phantom: std::marker::PhantomData,
126 }
127 }
128}
129
130impl<H: TriggerHandler, C, R> RegisterTriggerType<H, C, R> {
131 pub fn trigger_request_format<T: schemars::JsonSchema + Serialize>(
135 self,
136 ) -> RegisterTriggerType<H, T, R> {
137 RegisterTriggerType {
138 id: self.id,
139 description: self.description,
140 handler: self.handler,
141 trigger_request_format: json_schema_for::<T>(),
142 call_request_format: self.call_request_format,
143 _phantom: std::marker::PhantomData,
144 }
145 }
146
147 pub fn call_request_format<T: schemars::JsonSchema>(self) -> RegisterTriggerType<H, C, T> {
151 RegisterTriggerType {
152 id: self.id,
153 description: self.description,
154 handler: self.handler,
155 trigger_request_format: self.trigger_request_format,
156 call_request_format: json_schema_for::<T>(),
157 _phantom: std::marker::PhantomData,
158 }
159 }
160}
161
162#[derive(Clone)]
168pub struct TriggerTypeRef<C = Value, R = Value> {
169 iii: III,
170 trigger_type_id: String,
171 _phantom: std::marker::PhantomData<(C, R)>,
172}
173
174impl<C: Serialize, R> TriggerTypeRef<C, R> {
175 pub fn register_trigger(
177 &self,
178 function_id: impl Into<String>,
179 config: C,
180 ) -> Result<Trigger, IIIError> {
181 self.register_trigger_with_metadata(function_id, config, None)
182 }
183
184 pub fn register_trigger_with_metadata(
186 &self,
187 function_id: impl Into<String>,
188 config: C,
189 metadata: Option<Value>,
190 ) -> Result<Trigger, IIIError> {
191 self.iii.register_trigger(RegisterTriggerInput {
192 trigger_type: self.trigger_type_id.clone(),
193 function_id: function_id.into(),
194 config: serde_json::to_value(config).map_err(|e| IIIError::Handler(e.to_string()))?,
195 metadata,
196 })
197 }
198}
199
200impl<C, R> TriggerTypeRef<C, R>
201where
202 R: serde::de::DeserializeOwned + schemars::JsonSchema + Send + 'static,
203{
204 pub fn register_function<O, E, F>(&self, id: impl Into<String>, f: F) -> FunctionRef
207 where
208 O: Serialize + schemars::JsonSchema + Send + 'static,
209 E: std::fmt::Display + Send + 'static,
210 F: Fn(R) -> Result<O, E> + Send + Sync + 'static,
211 {
212 self.iii.register_function(RegisterFunction::new(id, f))
213 }
214
215 pub fn register_function_async<O, E, F, Fut>(&self, id: impl Into<String>, f: F) -> FunctionRef
218 where
219 O: Serialize + schemars::JsonSchema + Send + 'static,
220 E: std::fmt::Display + Send + 'static,
221 F: Fn(R) -> Fut + Send + Sync + 'static,
222 Fut: std::future::Future<Output = Result<O, E>> + Send + 'static,
223 {
224 self.iii
225 .register_function(RegisterFunction::new_async(id, f))
226 }
227}
228
/// Optional telemetry attributes nested inside [`WorkerMetadata`].
///
/// Every field is omitted from the serialized form when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WorkerTelemetryMeta {
    // `WorkerMetadata::default` fills this from the `LANG` / `LC_ALL` environment variables.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub project_name: Option<String>,
    // `register_worker_metadata` substitutes "iii-rust" when this is missing or empty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub framework: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub amplitude_api_key: Option<String>,
}
241
/// Metadata describing this worker; serialized and sent to
/// `engine::workers::register` by `register_worker_metadata`.
///
/// `pid` and `telemetry` are omitted from the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerMetadata {
    pub runtime: String,
    pub version: String,
    pub name: String,
    pub os: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pid: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub telemetry: Option<WorkerTelemetryMeta>,
}
254
255impl Default for WorkerMetadata {
256 fn default() -> Self {
257 let hostname = hostname::get()
258 .map(|h| h.to_string_lossy().to_string())
259 .unwrap_or_else(|_| "unknown".to_string());
260 let pid = std::process::id();
261 let os_info = format!(
262 "{} {} ({})",
263 std::env::consts::OS,
264 std::env::consts::ARCH,
265 std::env::consts::FAMILY
266 );
267
268 let language = std::env::var("LANG")
269 .or_else(|_| std::env::var("LC_ALL"))
270 .ok()
271 .filter(|s| !s.is_empty())
272 .map(|s| s.split('.').next().unwrap_or(&s).to_string());
273
274 Self {
275 runtime: "rust".to_string(),
276 version: SDK_VERSION.to_string(),
277 name: format!("{}:{}", hostname, pid),
278 os: os_info,
279 pid: Some(pid),
280 telemetry: Some(WorkerTelemetryMeta {
281 language,
282 ..Default::default()
283 }),
284 }
285 }
286}
287
/// Items sent over the client's unbounded outbound channel; the receiving end
/// is handed to `run_connection`.
// NOTE(review): the lint is silenced presumably because `Message` is far larger
// than the unit `Shutdown` variant and boxing every message was not wanted — confirm.
#[allow(clippy::large_enum_variant)]
enum Outbound {
    /// A protocol message queued by `send_message`.
    Message(Message),
    /// Sent by `shutdown` / `shutdown_async`.
    Shutdown,
}
293
/// Sender half on which the result of an in-flight `trigger` call is delivered.
type PendingInvocation = oneshot::Sender<Result<Value, IIIError>>;
295
/// Sink half of a split client WebSocket stream over a possibly-TLS TCP connection.
type WsTx = futures_util::stream::SplitSink<
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
    WsMessage,
>;
301
302fn inject_trace_headers() -> (Option<String>, Option<String>) {
304 use crate::telemetry::context;
305 (context::inject_traceparent(), context::inject_baggage())
306}
307
/// Connection state of an `III` client, as returned by `get_connection_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IIIConnectionState {
    /// Initial state; also set by `shutdown` / `shutdown_async`.
    Disconnected,
    /// Set by the connection loop before the first successful connection.
    Connecting,
    /// Set once the WebSocket connection has been established.
    Connected,
    /// Set by the connection loop when connecting again after an earlier success.
    Reconnecting,
    Failed,
}
317
/// Shared callback registered via `III::on_functions_available`; it receives
/// the list of functions decoded from the event payload.
pub type FunctionsAvailableCallback = Arc<dyn Fn(Vec<FunctionInfo>) + Send + Sync>;
320
/// Handle returned when a function is registered; lets the caller unregister it.
#[derive(Clone)]
pub struct FunctionRef {
    /// Id the function was registered under.
    pub id: String,
    // Closure run by `unregister`.
    unregister_fn: Arc<dyn Fn() + Send + Sync>,
}
326
327impl FunctionRef {
328 pub fn unregister(&self) {
329 (self.unregister_fn)();
330 }
331}
332
/// Conversion of a value into the handler part of a function registration.
pub trait IntoFunctionHandler {
    /// Consumes `self`, optionally adjusting the registration `message`, and
    /// returns the local handler to store — or `None` when there is no local
    /// handler (as for `HttpInvocationConfig`).
    fn into_parts(self, message: &mut RegisterFunctionMessage) -> Option<RemoteFunctionHandler>;
}
336
/// Anything that can be passed to `III::register_function`.
pub trait IntoFunctionRegistration {
    /// Splits `self` into the registration message and an optional local handler.
    fn into_registration(self) -> (RegisterFunctionMessage, Option<RemoteFunctionHandler>);
}
345
346impl IntoFunctionRegistration for RegisterFunction {
347 fn into_registration(self) -> (RegisterFunctionMessage, Option<RemoteFunctionHandler>) {
348 (self.message, Some(self.handler))
349 }
350}
351
352impl<H: IntoFunctionHandler> IntoFunctionRegistration for (RegisterFunctionMessage, H) {
353 fn into_registration(self) -> (RegisterFunctionMessage, Option<RemoteFunctionHandler>) {
354 let (mut message, handler) = self;
355 let handler = handler.into_parts(&mut message);
356 (message, handler)
357 }
358}
359
360impl IntoFunctionHandler for HttpInvocationConfig {
361 fn into_parts(self, message: &mut RegisterFunctionMessage) -> Option<RemoteFunctionHandler> {
362 message.invocation = Some(self);
363 None
364 }
365}
366
367impl<F, Fut> IntoFunctionHandler for F
368where
369 F: Fn(Value) -> Fut + Send + Sync + 'static,
370 Fut: std::future::Future<Output = Result<Value, IIIError>> + Send + 'static,
371{
372 fn into_parts(self, _message: &mut RegisterFunctionMessage) -> Option<RemoteFunctionHandler> {
373 Some(Arc::new(move |input: Value| Box::pin(self(input))))
374 }
375}
376
/// Wrapper produced by `iii_fn`: a type-erased handler plus the JSON schemas
/// derived from the wrapped function's argument and return types.
pub struct IIIFn<F = ()> {
    handler: RemoteFunctionHandler,
    request_format: Option<Value>,
    response_format: Option<Value>,
    // Remembers the wrapped function's type without storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
391
392fn json_schema_for<T: schemars::JsonSchema>() -> Option<Value> {
393 serde_json::to_value(
394 schemars::r#gen::SchemaSettings::draft07()
395 .into_generator()
396 .into_root_schema_for::<T>(),
397 )
398 .ok()
399}
400
/// Conversion of a synchronous function into a `RemoteFunctionHandler`.
///
/// `Marker` only distinguishes implementations; no value of it is produced.
#[doc(hidden)]
pub trait IntoSyncHandler<Marker>: Send + Sync + 'static {
    /// Consumes the function and returns the type-erased handler.
    fn into_handler(self) -> RemoteFunctionHandler;
    /// JSON schema of the request type; `None` unless overridden.
    fn request_format() -> Option<Value> {
        None
    }
    /// JSON schema of the response type; `None` unless overridden.
    fn response_format() -> Option<Value> {
        None
    }
}
413
414impl<F, T, R, E> IntoSyncHandler<(T, R, E)> for F
416where
417 F: Fn(T) -> Result<R, E> + Send + Sync + 'static,
418 T: serde::de::DeserializeOwned + schemars::JsonSchema + Send + 'static,
419 R: serde::Serialize + schemars::JsonSchema + Send + 'static,
420 E: std::fmt::Display + Send + 'static,
421{
422 fn into_handler(self) -> RemoteFunctionHandler {
423 Arc::new(move |input: Value| {
424 let output = serde_json::from_value::<T>(input)
425 .map_err(|e| IIIError::Handler(e.to_string()))
426 .and_then(|arg| (self)(arg).map_err(|e| IIIError::Handler(e.to_string())))
427 .and_then(|val| {
428 serde_json::to_value(&val).map_err(|e| IIIError::Handler(e.to_string()))
429 });
430 Box::pin(async move { output })
431 })
432 }
433
434 fn request_format() -> Option<Value> {
435 json_schema_for::<T>()
436 }
437
438 fn response_format() -> Option<Value> {
439 json_schema_for::<R>()
440 }
441}
442
443pub fn iii_fn<F, M>(f: F) -> IIIFn<F>
454where
455 F: IntoSyncHandler<M>,
456{
457 IIIFn {
458 request_format: F::request_format(),
459 response_format: F::response_format(),
460 handler: f.into_handler(),
461 _marker: std::marker::PhantomData,
462 }
463}
464
465impl<F> IntoFunctionHandler for IIIFn<F> {
466 fn into_parts(self, message: &mut RegisterFunctionMessage) -> Option<RemoteFunctionHandler> {
467 if message.request_format.is_none() {
468 message.request_format = self.request_format;
469 }
470 if message.response_format.is_none() {
471 message.response_format = self.response_format;
472 }
473 Some(self.handler)
474 }
475}
476
/// Wrapper produced by `iii_async_fn`: a type-erased handler plus the JSON
/// schemas derived from the wrapped async function's argument and output types.
pub struct IIIAsyncFn<F = ()> {
    handler: RemoteFunctionHandler,
    request_format: Option<Value>,
    response_format: Option<Value>,
    // Remembers the wrapped function's type without storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
491
/// Conversion of an asynchronous function into a `RemoteFunctionHandler`.
///
/// `Marker` only distinguishes implementations; no value of it is produced.
#[doc(hidden)]
pub trait IntoAsyncHandler<Marker>: Send + Sync + 'static {
    /// Consumes the function and returns the type-erased handler.
    fn into_handler(self) -> RemoteFunctionHandler;
    /// JSON schema of the request type; `None` unless overridden.
    fn request_format() -> Option<Value> {
        None
    }
    /// JSON schema of the response type; `None` unless overridden.
    fn response_format() -> Option<Value> {
        None
    }
}
504
505impl<F, T, Fut, R, E> IntoAsyncHandler<(T, Fut, R, E)> for F
507where
508 F: Fn(T) -> Fut + Send + Sync + 'static,
509 T: serde::de::DeserializeOwned + schemars::JsonSchema + Send + 'static,
510 Fut: std::future::Future<Output = Result<R, E>> + Send + 'static,
511 R: serde::Serialize + schemars::JsonSchema + Send + 'static,
512 E: std::fmt::Display + Send + 'static,
513{
514 fn into_handler(self) -> RemoteFunctionHandler {
515 Arc::new(
516 move |input: Value| -> std::pin::Pin<
517 Box<dyn std::future::Future<Output = Result<Value, IIIError>> + Send>,
518 > {
519 match serde_json::from_value::<T>(input) {
520 Ok(arg) => {
521 let fut = (self)(arg);
522 Box::pin(async move {
523 fut.await
524 .map_err(|e| IIIError::Handler(e.to_string()))
525 .and_then(|val| {
526 serde_json::to_value(&val)
527 .map_err(|e| IIIError::Handler(e.to_string()))
528 })
529 })
530 }
531 Err(e) => Box::pin(async move { Err(IIIError::Handler(e.to_string())) }),
532 }
533 },
534 )
535 }
536
537 fn request_format() -> Option<Value> {
538 json_schema_for::<T>()
539 }
540
541 fn response_format() -> Option<Value> {
542 json_schema_for::<R>()
543 }
544}
545
546pub fn iii_async_fn<F, M>(f: F) -> IIIAsyncFn<F>
552where
553 F: IntoAsyncHandler<M>,
554{
555 IIIAsyncFn {
556 request_format: F::request_format(),
557 response_format: F::response_format(),
558 handler: f.into_handler(),
559 _marker: std::marker::PhantomData,
560 }
561}
562
563impl<F> IntoFunctionHandler for IIIAsyncFn<F> {
564 fn into_parts(self, message: &mut RegisterFunctionMessage) -> Option<RemoteFunctionHandler> {
565 if message.request_format.is_none() {
566 message.request_format = self.request_format;
567 }
568 if message.response_format.is_none() {
569 message.response_format = self.response_format;
570 }
571 Some(self.handler)
572 }
573}
574
/// A function registration ready to pass to `III::register_function`: the
/// registration message together with the local handler that serves it.
pub struct RegisterFunction {
    message: RegisterFunctionMessage,
    handler: RemoteFunctionHandler,
}
587
588impl RegisterFunction {
589 pub fn new<F, M>(id: impl Into<String>, f: F) -> Self
591 where
592 F: IntoSyncHandler<M>,
593 {
594 Self {
595 message: RegisterFunctionMessage {
596 id: id.into(),
597 description: None,
598 request_format: F::request_format(),
599 response_format: F::response_format(),
600 metadata: None,
601 invocation: None,
602 },
603 handler: f.into_handler(),
604 }
605 }
606
607 pub fn new_async<F, M>(id: impl Into<String>, f: F) -> Self
609 where
610 F: IntoAsyncHandler<M>,
611 {
612 Self {
613 message: RegisterFunctionMessage {
614 id: id.into(),
615 description: None,
616 request_format: F::request_format(),
617 response_format: F::response_format(),
618 metadata: None,
619 invocation: None,
620 },
621 handler: f.into_handler(),
622 }
623 }
624
625 pub fn description(mut self, desc: impl Into<String>) -> Self {
627 self.message.description = Some(desc.into());
628 self
629 }
630
631 pub fn metadata(mut self, meta: Value) -> Self {
633 self.message.metadata = Some(meta);
634 self
635 }
636
637 pub fn request_format(&self) -> Option<&Value> {
639 self.message.request_format.as_ref()
640 }
641
642 pub fn response_format(&self) -> Option<&Value> {
644 self.message.response_format.as_ref()
645 }
646}
647
/// Shared state behind every clone of an [`III`] client.
struct IIIInner {
    // Engine address given at construction.
    address: String,
    // Sending half of the outbound queue; `send_message` and `shutdown` push onto it.
    outbound: mpsc::UnboundedSender<Outbound>,
    // Receiving half of the outbound queue; taken once by `connect`.
    receiver: Mutex<Option<mpsc::UnboundedReceiver<Outbound>>>,
    // While false, `send_message` drops messages and the connection loop stops.
    running: AtomicBool,
    // Set by the first `connect` call so later calls return immediately.
    started: AtomicBool,
    // In-flight `trigger` calls keyed by invocation id.
    pending: Mutex<HashMap<Uuid, PendingInvocation>>,
    // Local registries keyed by id.
    functions: Mutex<HashMap<String, RemoteFunctionData>>,
    trigger_types: Mutex<HashMap<String, RemoteTriggerTypeData>>,
    triggers: Mutex<HashMap<String, RegisterTriggerMessage>>,
    services: Mutex<HashMap<String, RegisterServiceMessage>>,
    // Metadata sent by `register_worker_metadata`.
    worker_metadata: Mutex<Option<WorkerMetadata>>,
    connection_state: Mutex<IIIConnectionState>,
    // Handle of the "iii-connection" thread spawned by `connect`; joined by `shutdown`.
    connection_thread: Mutex<Option<std::thread::JoinHandle<()>>>,
    // Callbacks registered via `on_functions_available`, keyed by ids from the counter below.
    functions_available_callbacks: Mutex<HashMap<usize, FunctionsAvailableCallback>>,
    functions_available_callback_counter: AtomicUsize,
    // Id of the internal function that fans events out to the callbacks (generated lazily).
    functions_available_function_id: Mutex<Option<String>>,
    // Trigger feeding that function; removed when the last callback guard is dropped.
    functions_available_trigger: Mutex<Option<Trigger>>,
    // Extra headers added to the WebSocket handshake request, if set.
    headers: Mutex<Option<HashMap<String, String>>>,
    // Telemetry configuration consumed by `connect`.
    otel_config: Mutex<Option<OtelConfig>>,
}
669
/// Client handle; clones share the same underlying state through an `Arc`.
#[derive(Clone)]
pub struct III {
    inner: Arc<IIIInner>,
}
677
/// Guard returned by `III::on_functions_available`; dropping it removes the
/// callback it was created for.
pub struct FunctionsAvailableGuard {
    iii: III,
    // Key of the callback in `functions_available_callbacks`.
    callback_id: usize,
}
683
684impl Drop for FunctionsAvailableGuard {
685 fn drop(&mut self) {
686 let mut callbacks = self
687 .iii
688 .inner
689 .functions_available_callbacks
690 .lock_or_recover();
691 callbacks.remove(&self.callback_id);
692
693 if callbacks.is_empty() {
694 let mut trigger = self.iii.inner.functions_available_trigger.lock_or_recover();
695 if let Some(trigger) = trigger.take() {
696 trigger.unregister();
697 }
698 }
699 }
700}
701
702impl III {
703 pub fn new(address: &str) -> Self {
705 Self::with_metadata(address, WorkerMetadata::default())
706 }
707
708 pub fn with_metadata(address: &str, metadata: WorkerMetadata) -> Self {
710 let (tx, rx) = mpsc::unbounded_channel();
711 let inner = IIIInner {
712 address: address.into(),
713 outbound: tx,
714 receiver: Mutex::new(Some(rx)),
715 running: AtomicBool::new(false),
716 started: AtomicBool::new(false),
717 pending: Mutex::new(HashMap::new()),
718 functions: Mutex::new(HashMap::new()),
719 trigger_types: Mutex::new(HashMap::new()),
720 triggers: Mutex::new(HashMap::new()),
721 services: Mutex::new(HashMap::new()),
722 worker_metadata: Mutex::new(Some(metadata)),
723 connection_state: Mutex::new(IIIConnectionState::Disconnected),
724 connection_thread: Mutex::new(None),
725 functions_available_callbacks: Mutex::new(HashMap::new()),
726 functions_available_callback_counter: AtomicUsize::new(0),
727 functions_available_function_id: Mutex::new(None),
728 functions_available_trigger: Mutex::new(None),
729 headers: Mutex::new(None),
730 otel_config: Mutex::new(None),
731 };
732 Self {
733 inner: Arc::new(inner),
734 }
735 }
736
737 pub fn address(&self) -> &str {
739 &self.inner.address
740 }
741
742 pub fn set_metadata(&self, metadata: WorkerMetadata) {
744 *self.inner.worker_metadata.lock_or_recover() = Some(metadata);
745 }
746
747 pub fn set_headers(&self, headers: HashMap<String, String>) {
749 *self.inner.headers.lock_or_recover() = Some(headers);
750 }
751
752 pub fn set_otel_config(&self, config: OtelConfig) {
754 *self.inner.otel_config.lock_or_recover() = Some(config);
755 }
756
757 pub(crate) fn connect(&self) {
758 if self.inner.started.swap(true, Ordering::SeqCst) {
759 return;
760 }
761
762 let receiver = self.inner.receiver.lock_or_recover().take();
763 let Some(rx) = receiver else { return };
764
765 self.inner.running.store(true, Ordering::SeqCst);
766
767 let iii = self.clone();
768
769 let otel_config = {
770 let mut config = self
771 .inner
772 .otel_config
773 .lock_or_recover()
774 .take()
775 .unwrap_or_default();
776 if config.engine_ws_url.is_none() {
777 config.engine_ws_url = Some(self.inner.address.clone());
778 }
779 config
780 };
781
782 let handle = std::thread::Builder::new()
788 .name("iii-connection".into())
789 .spawn(move || {
790 let rt = tokio::runtime::Builder::new_current_thread()
791 .enable_all()
792 .build()
793 .expect("failed to create iii connection runtime");
794
795 rt.block_on(async move {
796 let otel_active = telemetry::init_otel(otel_config).await;
797
798 iii.run_connection(rx).await;
799
800 if otel_active {
801 telemetry::shutdown_otel().await;
802 }
803 });
804 })
805 .expect("failed to spawn iii connection thread");
806
807 *self.inner.connection_thread.lock_or_recover() = Some(handle);
808 }
809
810 pub fn shutdown(&self) {
816 self.inner.running.store(false, Ordering::SeqCst);
817 let _ = self.inner.outbound.send(Outbound::Shutdown);
818 self.set_connection_state(IIIConnectionState::Disconnected);
819
820 if let Some(handle) = self.inner.connection_thread.lock_or_recover().take() {
821 let _ = handle.join();
822 }
823 }
824
825 pub async fn shutdown_async(&self) {
837 self.inner.running.store(false, Ordering::SeqCst);
838 let _ = self.inner.outbound.send(Outbound::Shutdown);
839 self.set_connection_state(IIIConnectionState::Disconnected);
840 }
841
842 fn register_function_inner(
843 &self,
844 message: RegisterFunctionMessage,
845 handler: Option<RemoteFunctionHandler>,
846 ) -> FunctionRef {
847 let id = message.id.clone();
848 if id.trim().is_empty() {
849 panic!("id is required");
850 }
851 let data = RemoteFunctionData {
852 message: message.clone(),
853 handler,
854 };
855 let mut funcs = self.inner.functions.lock_or_recover();
856 match funcs.entry(id.clone()) {
857 std::collections::hash_map::Entry::Occupied(_) => {
858 panic!("function id '{}' already registered", id);
859 }
860 std::collections::hash_map::Entry::Vacant(entry) => {
861 entry.insert(data);
862 }
863 }
864 drop(funcs);
865 let _ = self.send_message(message.to_message());
866
867 let iii = self.clone();
868 let unregister_id = id.clone();
869 let unregister_fn = Arc::new(move || {
870 let _ = iii.inner.functions.lock_or_recover().remove(&unregister_id);
871 let _ = iii.send_message(Message::UnregisterFunction {
872 id: unregister_id.clone(),
873 });
874 });
875
876 FunctionRef { id, unregister_fn }
877 }
878
879 pub fn register_function<R: IntoFunctionRegistration>(&self, registration: R) -> FunctionRef {
918 let (message, handler) = registration.into_registration();
919 self.register_function_inner(message, handler)
920 }
921
922 pub fn register_function_with<H: IntoFunctionHandler>(
924 &self,
925 mut message: RegisterFunctionMessage,
926 handler: H,
927 ) -> FunctionRef {
928 let handler = handler.into_parts(&mut message);
929 self.register_function_inner(message, handler)
930 }
931
932 pub fn register_service(&self, message: RegisterServiceMessage) {
937 self.inner
938 .services
939 .lock_or_recover()
940 .insert(message.id.clone(), message.clone());
941 let _ = self.send_message(message.to_message());
942 }
943
944 pub fn register_trigger_type<H, C, R>(
974 &self,
975 registration: RegisterTriggerType<H, C, R>,
976 ) -> TriggerTypeRef<C, R>
977 where
978 H: TriggerHandler + 'static,
979 {
980 let message = RegisterTriggerTypeMessage {
981 id: registration.id,
982 description: registration.description,
983 trigger_request_format: registration.trigger_request_format,
984 call_request_format: registration.call_request_format,
985 };
986
987 let trigger_type_id = message.id.clone();
988
989 self.inner.trigger_types.lock_or_recover().insert(
990 message.id.clone(),
991 RemoteTriggerTypeData {
992 message: message.clone(),
993 handler: Arc::new(registration.handler),
994 },
995 );
996
997 let _ = self.send_message(message.to_message());
998
999 TriggerTypeRef {
1000 iii: self.clone(),
1001 trigger_type_id,
1002 _phantom: std::marker::PhantomData,
1003 }
1004 }
1005
1006 pub fn unregister_trigger_type(&self, id: impl Into<String>) {
1008 let id = id.into();
1009 self.inner.trigger_types.lock_or_recover().remove(&id);
1010 let msg = UnregisterTriggerTypeMessage { id };
1011 let _ = self.send_message(msg.to_message());
1012 }
1013
1014 pub fn register_trigger(&self, input: RegisterTriggerInput) -> Result<Trigger, IIIError> {
1035 let id = Uuid::new_v4().to_string();
1036 let message = RegisterTriggerMessage {
1037 id: id.clone(),
1038 trigger_type: input.trigger_type,
1039 function_id: input.function_id,
1040 config: input.config,
1041 metadata: input.metadata,
1042 };
1043
1044 self.inner
1045 .triggers
1046 .lock_or_recover()
1047 .insert(message.id.clone(), message.clone());
1048 let _ = self.send_message(message.to_message());
1049
1050 let iii = self.clone();
1051 let trigger_type = message.trigger_type.clone();
1052 let unregister_id = message.id.clone();
1053 let unregister_fn = Arc::new(move || {
1054 let _ = iii.inner.triggers.lock_or_recover().remove(&unregister_id);
1055 let msg = UnregisterTriggerMessage {
1056 id: unregister_id.clone(),
1057 trigger_type: trigger_type.clone(),
1058 };
1059 let _ = iii.send_message(msg.to_message());
1060 });
1061
1062 Ok(Trigger::new(unregister_fn))
1063 }
1064
1065 pub async fn trigger(
1105 &self,
1106 request: impl Into<crate::protocol::TriggerRequest>,
1107 ) -> Result<Value, IIIError> {
1108 let req = request.into();
1109 let (tp, bg) = inject_trace_headers();
1110
1111 if matches!(req.action, Some(TriggerAction::Void)) {
1113 self.send_message(Message::InvokeFunction {
1114 invocation_id: None,
1115 function_id: req.function_id,
1116 data: req.payload,
1117 traceparent: tp,
1118 baggage: bg,
1119 action: req.action,
1120 })?;
1121 return Ok(Value::Null);
1122 }
1123
1124 let timeout = Duration::from_millis(req.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS));
1126 let invocation_id = Uuid::new_v4();
1127 let (tx, rx) = oneshot::channel();
1128
1129 self.inner
1130 .pending
1131 .lock_or_recover()
1132 .insert(invocation_id, tx);
1133
1134 self.send_message(Message::InvokeFunction {
1135 invocation_id: Some(invocation_id),
1136 function_id: req.function_id,
1137 data: req.payload,
1138 traceparent: tp,
1139 baggage: bg,
1140 action: req.action,
1141 })?;
1142
1143 match tokio::time::timeout(timeout, rx).await {
1144 Ok(Ok(result)) => result,
1145 Ok(Err(_)) => Err(IIIError::NotConnected),
1146 Err(_) => {
1147 self.inner.pending.lock_or_recover().remove(&invocation_id);
1148 Err(IIIError::Timeout)
1149 }
1150 }
1151 }
1152
1153 pub fn get_connection_state(&self) -> IIIConnectionState {
1155 *self.inner.connection_state.lock_or_recover()
1156 }
1157
1158 fn set_connection_state(&self, state: IIIConnectionState) {
1159 let mut current = self.inner.connection_state.lock_or_recover();
1160 if *current == state {
1161 return;
1162 }
1163 *current = state;
1164 }
1165
1166 pub async fn list_functions(&self) -> Result<Vec<FunctionInfo>, IIIError> {
1168 let result = self
1169 .trigger(TriggerRequest {
1170 function_id: "engine::functions::list".to_string(),
1171 payload: serde_json::json!({}),
1172 action: None,
1173 timeout_ms: None,
1174 })
1175 .await?;
1176
1177 let functions = result
1178 .get("functions")
1179 .and_then(|v| serde_json::from_value::<Vec<FunctionInfo>>(v.clone()).ok())
1180 .unwrap_or_default();
1181
1182 Ok(functions)
1183 }
1184
1185 pub fn on_functions_available<F>(&self, callback: F) -> FunctionsAvailableGuard
1188 where
1189 F: Fn(Vec<FunctionInfo>) + Send + Sync + 'static,
1190 {
1191 let callback = Arc::new(callback);
1192 let callback_id = self
1193 .inner
1194 .functions_available_callback_counter
1195 .fetch_add(1, Ordering::Relaxed);
1196
1197 self.inner
1198 .functions_available_callbacks
1199 .lock_or_recover()
1200 .insert(callback_id, callback);
1201
1202 let mut trigger_guard = self.inner.functions_available_trigger.lock_or_recover();
1204 if trigger_guard.is_none() {
1205 let function_id = {
1207 let mut path_guard = self.inner.functions_available_function_id.lock_or_recover();
1208 if path_guard.is_none() {
1209 let path = format!("iii.on_functions_available.{}", Uuid::new_v4());
1210 *path_guard = Some(path.clone());
1211 path
1212 } else {
1213 path_guard.clone().unwrap()
1214 }
1215 };
1216
1217 let function_exists = self
1219 .inner
1220 .functions
1221 .lock_or_recover()
1222 .contains_key(&function_id);
1223 if !function_exists {
1224 let iii = self.clone();
1225 self.register_function_with(
1226 RegisterFunctionMessage {
1227 id: function_id.clone(),
1228 description: None,
1229 request_format: None,
1230 response_format: None,
1231 metadata: None,
1232 invocation: None,
1233 },
1234 move |input: Value| {
1235 let iii = iii.clone();
1236 async move {
1237 let functions = input
1238 .get("functions")
1239 .and_then(|v| {
1240 serde_json::from_value::<Vec<FunctionInfo>>(v.clone()).ok()
1241 })
1242 .unwrap_or_default();
1243
1244 let callbacks =
1245 iii.inner.functions_available_callbacks.lock_or_recover();
1246 for cb in callbacks.values() {
1247 cb(functions.clone());
1248 }
1249 Ok(Value::Null)
1250 }
1251 },
1252 );
1253 }
1254
1255 match self.register_trigger(RegisterTriggerInput {
1256 trigger_type: "engine::functions-available".to_string(),
1257 function_id,
1258 config: serde_json::json!({}),
1259 metadata: None,
1260 }) {
1261 Ok(trigger) => {
1262 *trigger_guard = Some(trigger);
1263 }
1264 Err(err) => {
1265 tracing::warn!(error = %err, "Failed to register functions_available trigger");
1266 }
1267 }
1268 }
1269
1270 FunctionsAvailableGuard {
1271 iii: self.clone(),
1272 callback_id,
1273 }
1274 }
1275
1276 pub async fn list_workers(&self) -> Result<Vec<WorkerInfo>, IIIError> {
1278 let result = self
1279 .trigger(TriggerRequest {
1280 function_id: "engine::workers::list".to_string(),
1281 payload: serde_json::json!({}),
1282 action: None,
1283 timeout_ms: None,
1284 })
1285 .await?;
1286
1287 let workers = result
1288 .get("workers")
1289 .and_then(|v| serde_json::from_value::<Vec<WorkerInfo>>(v.clone()).ok())
1290 .unwrap_or_default();
1291
1292 Ok(workers)
1293 }
1294
1295 pub async fn list_triggers(
1297 &self,
1298 include_internal: bool,
1299 ) -> Result<Vec<TriggerInfo>, IIIError> {
1300 let result = self
1301 .trigger(TriggerRequest {
1302 function_id: "engine::triggers::list".to_string(),
1303 payload: serde_json::json!({ "include_internal": include_internal }),
1304 action: None,
1305 timeout_ms: None,
1306 })
1307 .await?;
1308
1309 let triggers = result
1310 .get("triggers")
1311 .and_then(|v| serde_json::from_value::<Vec<TriggerInfo>>(v.clone()).ok())
1312 .unwrap_or_default();
1313
1314 Ok(triggers)
1315 }
1316
1317 pub async fn list_trigger_types(
1320 &self,
1321 include_internal: bool,
1322 ) -> Result<Vec<TriggerTypeInfo>, IIIError> {
1323 let result = self
1324 .trigger(TriggerRequest {
1325 function_id: "engine::trigger-types::list".to_string(),
1326 payload: serde_json::json!({ "include_internal": include_internal }),
1327 action: None,
1328 timeout_ms: None,
1329 })
1330 .await?;
1331
1332 let trigger_types = result
1333 .get("trigger_types")
1334 .and_then(|v| serde_json::from_value::<Vec<TriggerTypeInfo>>(v.clone()).ok())
1335 .unwrap_or_default();
1336
1337 Ok(trigger_types)
1338 }
1339
1340 pub async fn create_channel(&self, buffer_size: Option<usize>) -> Result<Channel, IIIError> {
1345 let result = self
1346 .trigger(TriggerRequest {
1347 function_id: "engine::channels::create".to_string(),
1348 payload: serde_json::json!({ "buffer_size": buffer_size }),
1349 action: None,
1350 timeout_ms: None,
1351 })
1352 .await?;
1353
1354 let writer_ref: StreamChannelRef = serde_json::from_value(
1355 result
1356 .get("writer")
1357 .cloned()
1358 .ok_or_else(|| IIIError::Serde("missing 'writer' in channel response".into()))?,
1359 )
1360 .map_err(|e| IIIError::Serde(e.to_string()))?;
1361
1362 let reader_ref: StreamChannelRef = serde_json::from_value(
1363 result
1364 .get("reader")
1365 .cloned()
1366 .ok_or_else(|| IIIError::Serde("missing 'reader' in channel response".into()))?,
1367 )
1368 .map_err(|e| IIIError::Serde(e.to_string()))?;
1369
1370 Ok(Channel {
1371 writer: ChannelWriter::new(&self.inner.address, &writer_ref),
1372 reader: ChannelReader::new(&self.inner.address, &reader_ref),
1373 writer_ref,
1374 reader_ref,
1375 })
1376 }
1377
1378 fn register_worker_metadata(&self) {
1380 if let Some(mut metadata) = self.inner.worker_metadata.lock_or_recover().clone() {
1381 let fw = metadata
1382 .telemetry
1383 .as_ref()
1384 .and_then(|t| t.framework.as_deref())
1385 .unwrap_or("");
1386 if fw.is_empty() {
1387 let telem = metadata.telemetry.get_or_insert_with(Default::default);
1388 telem.framework = Some("iii-rust".to_string());
1389 }
1390 if let Ok(value) = serde_json::to_value(metadata) {
1391 let _ = self.send_message(Message::InvokeFunction {
1392 invocation_id: None,
1393 function_id: "engine::workers::register".to_string(),
1394 data: value,
1395 traceparent: None,
1396 baggage: None,
1397 action: Some(TriggerAction::Void),
1398 });
1399 }
1400 }
1401 }
1402
1403 fn send_message(&self, message: Message) -> Result<(), IIIError> {
1404 if !self.inner.running.load(Ordering::SeqCst) {
1405 return Ok(());
1406 }
1407
1408 self.inner
1409 .outbound
1410 .send(Outbound::Message(message))
1411 .map_err(|_| IIIError::NotConnected)
1412 }
1413
1414 async fn run_connection(&self, mut rx: mpsc::UnboundedReceiver<Outbound>) {
1415 let mut queue: Vec<Message> = Vec::new();
1416 let mut has_connected_before = false;
1417
1418 while self.inner.running.load(Ordering::SeqCst) {
1419 self.set_connection_state(if has_connected_before {
1420 IIIConnectionState::Reconnecting
1421 } else {
1422 IIIConnectionState::Connecting
1423 });
1424
1425 let custom_headers = self.inner.headers.lock_or_recover().clone();
1426
1427 let connect_result = if let Some(ref h) = custom_headers {
1428 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
1429 use tokio_tungstenite::tungstenite::http;
1430 let mut request = self
1431 .inner
1432 .address
1433 .as_str()
1434 .into_client_request()
1435 .expect("valid ws request");
1436 for (k, v) in h {
1437 if let (Ok(name), Ok(val)) = (
1438 http::header::HeaderName::from_bytes(k.as_bytes()),
1439 http::header::HeaderValue::from_str(v),
1440 ) {
1441 request.headers_mut().insert(name, val);
1442 }
1443 }
1444 connect_async(request).await
1445 } else {
1446 connect_async(&self.inner.address).await
1447 };
1448
1449 match connect_result {
1450 Ok((stream, _)) => {
1451 tracing::info!(address = %self.inner.address, "iii connected");
1452 has_connected_before = true;
1453 self.set_connection_state(IIIConnectionState::Connected);
1454 let (mut ws_tx, mut ws_rx) = stream.split();
1455
1456 queue.extend(self.collect_registrations());
1457 Self::dedupe_registrations(&mut queue);
1458 if let Err(err) = self.flush_queue(&mut ws_tx, &mut queue).await {
1459 tracing::warn!(error = %err, "failed to flush queue");
1460 sleep(Duration::from_secs(2)).await;
1461 continue;
1462 }
1463
1464 self.register_worker_metadata();
1466
1467 let mut should_reconnect = false;
1468
1469 while self.inner.running.load(Ordering::SeqCst) && !should_reconnect {
1470 tokio::select! {
1471 outgoing = rx.recv() => {
1472 match outgoing {
1473 Some(Outbound::Message(message)) => {
1474 if let Err(err) = self.send_ws(&mut ws_tx, &message).await {
1475 tracing::warn!(error = %err, "send failed; reconnecting");
1476 queue.push(message);
1477 should_reconnect = true;
1478 }
1479 }
1480 Some(Outbound::Shutdown) => {
1481 self.inner.running.store(false, Ordering::SeqCst);
1482 return;
1483 }
1484 None => {
1485 self.inner.running.store(false, Ordering::SeqCst);
1486 return;
1487 }
1488 }
1489 }
1490 incoming = ws_rx.next() => {
1491 match incoming {
1492 Some(Ok(frame)) => {
1493 if let Err(err) = self.handle_frame(frame) {
1494 tracing::warn!(error = %err, "failed to handle frame");
1495 }
1496 }
1497 Some(Err(err)) => {
1498 tracing::warn!(error = %err, "websocket receive error");
1499 should_reconnect = true;
1500 }
1501 None => {
1502 should_reconnect = true;
1503 }
1504 }
1505 }
1506 }
1507 }
1508 }
1509 Err(err) => {
1510 tracing::warn!(error = %err, "failed to connect; retrying");
1511 }
1512 }
1513
1514 if self.inner.running.load(Ordering::SeqCst) {
1515 sleep(Duration::from_secs(2)).await;
1516 }
1517 }
1518 }
1519
1520 fn collect_registrations(&self) -> Vec<Message> {
1521 let mut messages = Vec::new();
1522
1523 for trigger_type in self.inner.trigger_types.lock_or_recover().values() {
1524 messages.push(trigger_type.message.to_message());
1525 }
1526
1527 for service in self.inner.services.lock_or_recover().values() {
1528 messages.push(service.to_message());
1529 }
1530
1531 for function in self.inner.functions.lock_or_recover().values() {
1532 messages.push(function.message.to_message());
1533 }
1534
1535 for trigger in self.inner.triggers.lock_or_recover().values() {
1536 messages.push(trigger.to_message());
1537 }
1538
1539 messages
1540 }
1541
1542 fn dedupe_registrations(queue: &mut Vec<Message>) {
1543 let mut seen = HashSet::new();
1544 let mut deduped_rev = Vec::with_capacity(queue.len());
1545
1546 for message in queue.iter().rev() {
1547 let key = match message {
1548 Message::RegisterTriggerType { id, .. } => format!("trigger_type:{id}"),
1549 Message::RegisterTrigger { id, .. } => format!("trigger:{id}"),
1550 Message::RegisterFunction { id, .. } => {
1551 format!("function:{id}")
1552 }
1553 Message::RegisterService { id, .. } => format!("service:{id}"),
1554 _ => {
1555 deduped_rev.push(message.clone());
1556 continue;
1557 }
1558 };
1559
1560 if seen.insert(key) {
1561 deduped_rev.push(message.clone());
1562 }
1563 }
1564
1565 deduped_rev.reverse();
1566 *queue = deduped_rev;
1567 }
1568
1569 async fn flush_queue(
1570 &self,
1571 ws_tx: &mut WsTx,
1572 queue: &mut Vec<Message>,
1573 ) -> Result<(), IIIError> {
1574 let mut drained = Vec::new();
1575 std::mem::swap(queue, &mut drained);
1576
1577 let mut iter = drained.into_iter();
1578 while let Some(message) = iter.next() {
1579 if let Err(err) = self.send_ws(ws_tx, &message).await {
1580 queue.push(message);
1581 queue.extend(iter);
1582 return Err(err);
1583 }
1584 }
1585
1586 Ok(())
1587 }
1588
1589 async fn send_ws(&self, ws_tx: &mut WsTx, message: &Message) -> Result<(), IIIError> {
1590 let payload = serde_json::to_string(message)?;
1591 ws_tx.send(WsMessage::Text(payload.into())).await?;
1592 Ok(())
1593 }
1594
1595 fn handle_frame(&self, frame: WsMessage) -> Result<(), IIIError> {
1596 match frame {
1597 WsMessage::Text(text) => self.handle_message(&text),
1598 WsMessage::Binary(bytes) => {
1599 let text = String::from_utf8_lossy(&bytes).to_string();
1600 self.handle_message(&text)
1601 }
1602 _ => Ok(()),
1603 }
1604 }
1605
1606 fn handle_message(&self, payload: &str) -> Result<(), IIIError> {
1607 let message: Message = serde_json::from_str(payload)?;
1608
1609 match message {
1610 Message::InvocationResult {
1611 invocation_id,
1612 result,
1613 error,
1614 ..
1615 } => {
1616 self.handle_invocation_result(invocation_id, result, error);
1617 }
1618 Message::InvokeFunction {
1619 invocation_id,
1620 function_id,
1621 data,
1622 traceparent,
1623 baggage,
1624 action: _,
1625 } => {
1626 self.handle_invoke_function(invocation_id, function_id, data, traceparent, baggage);
1627 }
1628 Message::RegisterTrigger {
1629 id,
1630 trigger_type,
1631 function_id,
1632 config,
1633 metadata,
1634 } => {
1635 self.handle_register_trigger(id, trigger_type, function_id, config, metadata);
1636 }
1637 Message::Ping => {
1638 let _ = self.send_message(Message::Pong);
1639 }
1640 Message::WorkerRegistered { worker_id } => {
1641 tracing::debug!(worker_id = %worker_id, "Worker registered");
1642 }
1643 _ => {}
1644 }
1645
1646 Ok(())
1647 }
1648
1649 fn handle_invocation_result(
1650 &self,
1651 invocation_id: Uuid,
1652 result: Option<Value>,
1653 error: Option<ErrorBody>,
1654 ) {
1655 let sender = self.inner.pending.lock_or_recover().remove(&invocation_id);
1656 if let Some(sender) = sender {
1657 let result = match error {
1658 Some(error) => Err(IIIError::Remote {
1659 code: error.code,
1660 message: error.message,
1661 stacktrace: error.stacktrace,
1662 }),
1663 None => Ok(result.unwrap_or(Value::Null)),
1664 };
1665 let _ = sender.send(result);
1666 }
1667 }
1668
/// Handles an incoming request to invoke a locally registered function.
///
/// Looks up `function_id` in the function registry. When no local handler is
/// available, an error result is sent back (only if `invocation_id` is
/// present) and the method returns. Otherwise the handler is run in a spawned
/// tokio task inside an OpenTelemetry span, and its outcome is sent back as
/// an `InvocationResult` when `invocation_id` is present.
///
/// `traceparent` and `baggage` are passed to `extract_context` to build the
/// parent context of the span.
fn handle_invoke_function(
    &self,
    invocation_id: Option<Uuid>,
    function_id: String,
    data: Value,
    traceparent: Option<String>,
    baggage: Option<String>,
) {
    tracing::debug!(function_id = %function_id, traceparent = ?traceparent, baggage = ?baggage, "Invoking function");

    // Clone the registry entry so the lock is released at the end of this statement.
    let func_data = self
        .inner
        .functions
        .lock_or_recover()
        .get(&function_id)
        .cloned();
    let handler = func_data.as_ref().and_then(|d| d.handler.clone());

    let Some(handler) = handler else {
        // Distinguish "registered but without a local handler" from "unknown".
        let (code, message) = match &func_data {
            Some(_) => (
                "function_not_invokable".to_string(),
                "Function is HTTP-invoked and cannot be invoked locally".to_string(),
            ),
            None => (
                "function_not_found".to_string(),
                "Function not found".to_string(),
            ),
        };
        tracing::warn!(function_id = %function_id, "Invocation: {}", message);

        // Without an invocation id there is nothing to reply to.
        if let Some(invocation_id) = invocation_id {
            // NOTE(review): presumably yields the trace headers for the
            // currently active context — confirm against `inject_trace_headers`.
            let (resp_tp, resp_bg) = inject_trace_headers();

            let error = ErrorBody {
                code,
                message,
                stacktrace: None,
            };
            let result = self.send_message(Message::InvocationResult {
                invocation_id,
                function_id,
                result: None,
                error: Some(error),
                traceparent: resp_tp,
                baggage: resp_bg,
            });

            if let Err(err) = result {
                tracing::warn!(error = %err, "error sending invocation result");
            }
        }
        return;
    };

    // The spawned task needs its own handle to send the reply.
    let iii = self.clone();

    tokio::spawn(async move {
        // Build a server-kind span named `call <function_id>` whose parent is
        // the context extracted from the incoming headers.
        let otel_cx = {
            use crate::telemetry::context::extract_context;
            use opentelemetry::trace::{SpanKind, TraceContextExt, Tracer};

            let parent_cx = extract_context(traceparent.as_deref(), baggage.as_deref());
            let tracer = opentelemetry::global::tracer("iii-rust-sdk");
            let span = tracer
                .span_builder(format!("call {}", function_id))
                .with_kind(SpanKind::Server)
                .start_with_context(&tracer, &parent_cx);
            parent_cx.with_span(span)
        };

        // Run the handler future with the span's context applied to it.
        let result = {
            use opentelemetry::trace::FutureExt as OtelFutureExt;
            handler(data).with_context(otel_cx.clone()).await
        };

        // Stack trace recorded on the span; reused below for the error reply.
        let mut error_stacktrace: Option<String> = None;
        {
            use opentelemetry::KeyValue;
            use opentelemetry::trace::{Status, TraceContextExt};
            let span = otel_cx.span();
            match &result {
                Ok(_) => span.set_status(Status::Ok),
                Err(err) => {
                    // Remote errors keep their own code/message/stacktrace;
                    // a backtrace is captured here when none is available.
                    let (exc_type, exc_message, stacktrace) = match err {
                        IIIError::Remote {
                            code,
                            message,
                            stacktrace,
                        } => (
                            code.clone(),
                            message.clone(),
                            stacktrace.clone().unwrap_or_else(|| {
                                std::backtrace::Backtrace::force_capture().to_string()
                            }),
                        ),
                        other => (
                            "InvocationError".to_string(),
                            other.to_string(),
                            std::backtrace::Backtrace::force_capture().to_string(),
                        ),
                    };
                    span.set_status(Status::error(exc_message.clone()));
                    span.add_event(
                        "exception",
                        vec![
                            KeyValue::new("exception.type", exc_type),
                            KeyValue::new("exception.message", exc_message),
                            KeyValue::new("exception.stacktrace", stacktrace.clone()),
                        ],
                    );
                    error_stacktrace = Some(stacktrace);
                }
            }
        }

        if let Some(invocation_id) = invocation_id {
            // Attach the span context only while the response headers are produced.
            // NOTE(review): presumably so the headers describe this span — confirm.
            let (resp_tp, resp_bg) = {
                let _guard = otel_cx.attach();
                inject_trace_headers()
            };

            let message = match result {
                Ok(value) => Message::InvocationResult {
                    invocation_id,
                    function_id,
                    result: Some(value),
                    error: None,
                    traceparent: resp_tp,
                    baggage: resp_bg,
                },
                Err(err) => {
                    let error_body = match err {
                        IIIError::Remote {
                            code,
                            message,
                            stacktrace,
                        } => ErrorBody {
                            code,
                            message,
                            // Prefer the error's own trace, then the one recorded above.
                            stacktrace: stacktrace.or(error_stacktrace).or_else(|| {
                                Some(std::backtrace::Backtrace::force_capture().to_string())
                            }),
                        },
                        other => ErrorBody {
                            code: "invocation_failed".to_string(),
                            message: other.to_string(),
                            stacktrace: error_stacktrace.or_else(|| {
                                Some(std::backtrace::Backtrace::force_capture().to_string())
                            }),
                        },
                    };
                    Message::InvocationResult {
                        invocation_id,
                        function_id,
                        result: None,
                        error: Some(error_body),
                        traceparent: resp_tp,
                        baggage: resp_bg,
                    }
                }
            };

            // A failure to queue the reply is ignored.
            let _ = iii.send_message(message);
        } else if let Err(err) = result {
            // No invocation id: nobody to reply to, so only log the failure.
            tracing::warn!(error = %err, "error handling async invocation");
        }
    });
}
1847
1848 fn handle_register_trigger(
1849 &self,
1850 id: String,
1851 trigger_type: String,
1852 function_id: String,
1853 config: Value,
1854 metadata: Option<Value>,
1855 ) {
1856 let handler = self
1857 .inner
1858 .trigger_types
1859 .lock_or_recover()
1860 .get(&trigger_type)
1861 .map(|data| data.handler.clone());
1862
1863 let iii = self.clone();
1864
1865 tokio::spawn(async move {
1866 let message = if let Some(handler) = handler {
1867 let config = TriggerConfig {
1868 id: id.clone(),
1869 function_id: function_id.clone(),
1870 config,
1871 metadata,
1872 };
1873
1874 match handler.register_trigger(config).await {
1875 Ok(()) => Message::TriggerRegistrationResult {
1876 id,
1877 trigger_type,
1878 function_id,
1879 error: None,
1880 },
1881 Err(err) => Message::TriggerRegistrationResult {
1882 id,
1883 trigger_type,
1884 function_id,
1885 error: Some(ErrorBody {
1886 code: "trigger_registration_failed".to_string(),
1887 message: err.to_string(),
1888 stacktrace: None,
1889 }),
1890 },
1891 }
1892 } else {
1893 Message::TriggerRegistrationResult {
1894 id,
1895 trigger_type,
1896 function_id,
1897 error: Some(ErrorBody {
1898 code: "trigger_type_not_found".to_string(),
1899 message: "Trigger type not found".to_string(),
1900 stacktrace: None,
1901 }),
1902 }
1903 };
1904
1905 let _ = iii.send_message(message);
1906 });
1907 }
1908}
1909
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use serde_json::json;

    use super::*;
    use crate::{
        InitOptions,
        protocol::{HttpInvocationConfig, HttpMethod, RegisterTriggerInput},
        register_worker,
    };

    /// Unregistering a trigger handle removes its entry from the trigger table.
    #[tokio::test]
    async fn register_trigger_unregister_removes_entry() {
        let worker = register_worker("ws://localhost:1234", InitOptions::default());
        let input = RegisterTriggerInput {
            trigger_type: "demo".to_string(),
            function_id: "functions.echo".to_string(),
            config: json!({ "foo": "bar" }),
            metadata: None,
        };
        let handle = worker.register_trigger(input).unwrap();

        let registered = worker.inner.triggers.lock().unwrap().len();
        assert_eq!(registered, 1);

        handle.unregister();

        let remaining = worker.inner.triggers.lock().unwrap().len();
        assert_eq!(remaining, 0);
    }

    /// An HTTP-invoked function is stored on registration and removed on unregister.
    #[tokio::test]
    async fn register_function_with_http_config_stores_and_unregister_removes() {
        let worker = register_worker("ws://localhost:1234", InitOptions::default());
        let http = HttpInvocationConfig {
            url: "https://example.com/invoke".to_string(),
            method: HttpMethod::Post,
            timeout_ms: Some(30000),
            headers: HashMap::new(),
            auth: None,
        };
        let registration = RegisterFunctionMessage {
            id: "external::my_lambda".to_string(),
            description: None,
            request_format: None,
            response_format: None,
            metadata: None,
            invocation: None,
        };

        let handle = worker.register_function_with(registration, http);

        assert_eq!(handle.id, "external::my_lambda");
        let registered = worker.inner.functions.lock().unwrap().len();
        assert_eq!(registered, 1);

        handle.unregister();

        let remaining = worker.inner.functions.lock().unwrap().len();
        assert_eq!(remaining, 0);
    }

    /// Registering a function with an empty id panics with "id is required".
    #[tokio::test]
    #[should_panic(expected = "id is required")]
    async fn register_function_rejects_empty_id() {
        let worker = register_worker("ws://localhost:1234", InitOptions::default());
        let http = HttpInvocationConfig {
            url: "https://example.com/invoke".to_string(),
            method: HttpMethod::Post,
            timeout_ms: None,
            headers: HashMap::new(),
            auth: None,
        };
        let registration = RegisterFunctionMessage {
            id: "".to_string(),
            description: None,
            request_format: None,
            response_format: None,
            metadata: None,
            invocation: None,
        };

        worker.register_function_with(registration, http);
    }

    /// A trigger that gets no reply times out and leaves no pending entry behind.
    #[tokio::test]
    async fn invoke_function_times_out_and_clears_pending() {
        let worker = register_worker("ws://localhost:1234", InitOptions::default());
        let request = TriggerRequest {
            function_id: "functions.echo".to_string(),
            payload: json!({ "a": 1 }),
            action: None,
            timeout_ms: Some(10),
        };

        let outcome = worker.trigger(request).await;

        assert!(matches!(outcome, Err(IIIError::Timeout)));
        let pending_is_empty = worker.inner.pending.lock().unwrap().is_empty();
        assert!(pending_is_empty);
    }
}