obeli_sk_concepts/
lib.rs

1use ::serde::{Deserialize, Serialize};
2use arbitrary::Arbitrary;
3use assert_matches::assert_matches;
4use async_trait::async_trait;
5use derivative::Derivative;
6pub use indexmap;
7use indexmap::IndexMap;
8use opentelemetry::propagation::{Extractor, Injector};
9pub use prefixed_ulid::ExecutionId;
10use prefixed_ulid::{ExecutionIdDerived, ExecutionIdParseError};
11use serde_json::Value;
12use std::{
13    borrow::Borrow,
14    fmt::{Debug, Display},
15    hash::Hash,
16    marker::PhantomData,
17    ops::Deref,
18    str::FromStr,
19    sync::Arc,
20    time::Duration,
21};
22use storage::{PendingStateFinishedError, PendingStateFinishedResultKind};
23use tracing::Span;
24use val_json::{
25    type_wrapper::{TypeConversionError, TypeWrapper},
26    wast_val::{WastVal, WastValWithType},
27    wast_val_ser::params,
28};
29use wasmtime::component::{Type, Val};
30
31#[cfg(feature = "rusqlite")]
32mod rusqlite_ext;
33pub mod storage;
34pub mod time;
35
/// Namespace of Obelisk's own packages; `is_namespace_obelisk` compares against it.
pub const NAMESPACE_OBELISK: &str = "obelisk";
/// Package-name suffix that marks an extension package; `is_extension` checks for it.
pub const SUFFIX_PKG_EXT: &str = "-obelisk-ext";
38
39pub type FinishedExecutionResult = Result<SupportedFunctionReturnValue, FinishedExecutionError>;
40
41#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
42pub enum FinishedExecutionError {
43    // Activity only
44    #[error("permanent timeout")]
45    PermanentTimeout,
46    // Workflow only
47    #[error("unhandled child execution error {child_execution_id}")]
48    UnhandledChildExecutionError {
49        child_execution_id: ExecutionIdDerived,
50        root_cause_id: ExecutionIdDerived,
51    },
52    #[error("permanent failure: {reason_full}")]
53    PermanentFailure {
54        // Exists just for extracting reason of an activity trap, to avoid "activity trap: " prefix.
55        reason_inner: String,
56        // Contains reason_inner embedded in the error message
57        reason_full: String,
58        kind: PermanentFailureKind,
59        detail: Option<String>,
60    },
61}
62impl FinishedExecutionError {
63    #[must_use]
64    pub fn as_pending_state_finished_error(&self) -> PendingStateFinishedError {
65        match self {
66            FinishedExecutionError::PermanentTimeout => PendingStateFinishedError::Timeout,
67            FinishedExecutionError::UnhandledChildExecutionError { .. } => {
68                PendingStateFinishedError::UnhandledChildExecutionError
69            }
70            FinishedExecutionError::PermanentFailure { .. } => {
71                PendingStateFinishedError::ExecutionFailure
72            }
73        }
74    }
75}
76
77#[derive(Debug, Clone, Copy, derive_more::Display, PartialEq, Eq, Serialize, Deserialize)]
78#[serde(rename_all = "snake_case")]
79pub enum PermanentFailureKind {
80    /// Applicable to Workflow
81    NondeterminismDetected,
82    /// Applicable to Workflow, Activity
83    ParamsParsingError,
84    /// Applicable to Workflow, Activity
85    CannotInstantiate,
86    /// Applicable to Workflow, Activity
87    ResultParsingError,
88    /// Applicable to Workflow
89    ImportedFunctionCallError,
90    /// Applicable to Activity
91    ActivityTrap,
92    /// Applicable to Workflow
93    WorkflowTrap,
94    /// Applicable to Workflow
95    JoinSetNameConflict,
96    /// Applicable to webhook endpoint
97    WebhookEndpointError,
98}
99
100#[derive(Debug, Clone, Copy, derive_more::Display, PartialEq, Eq, Serialize, Deserialize)]
101#[serde(rename_all = "snake_case")]
102pub enum TrapKind {
103    #[display("trap")]
104    Trap,
105    #[display("post_return_trap")]
106    PostReturnTrap,
107}
108
109#[derive(Clone, Eq, derive_more::Display)]
110pub enum StrVariant {
111    Static(&'static str),
112    Arc(Arc<str>),
113}
114
115impl StrVariant {
116    #[must_use]
117    pub const fn empty() -> StrVariant {
118        StrVariant::Static("")
119    }
120}
121
122impl From<String> for StrVariant {
123    fn from(value: String) -> Self {
124        StrVariant::Arc(Arc::from(value))
125    }
126}
127
128impl From<&'static str> for StrVariant {
129    fn from(value: &'static str) -> Self {
130        StrVariant::Static(value)
131    }
132}
133
134impl PartialEq for StrVariant {
135    fn eq(&self, other: &Self) -> bool {
136        match (self, other) {
137            (Self::Static(left), Self::Static(right)) => left == right,
138            (Self::Static(left), Self::Arc(right)) => *left == right.deref(),
139            (Self::Arc(left), Self::Arc(right)) => left == right,
140            (Self::Arc(left), Self::Static(right)) => left.deref() == *right,
141        }
142    }
143}
144
145impl Hash for StrVariant {
146    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
147        match self {
148            StrVariant::Static(val) => val.hash(state),
149            StrVariant::Arc(val) => {
150                let str: &str = val.deref();
151                str.hash(state);
152            }
153        }
154    }
155}
156
157impl Debug for StrVariant {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        Display::fmt(self, f)
160    }
161}
162
163impl Deref for StrVariant {
164    type Target = str;
165    fn deref(&self) -> &Self::Target {
166        match self {
167            Self::Arc(v) => v,
168            Self::Static(v) => v,
169        }
170    }
171}
172
173impl AsRef<str> for StrVariant {
174    fn as_ref(&self) -> &str {
175        match self {
176            Self::Arc(v) => v,
177            Self::Static(v) => v,
178        }
179    }
180}
181
182mod serde_strvariant {
183    use crate::StrVariant;
184    use serde::{
185        de::{self, Visitor},
186        Deserialize, Deserializer, Serialize, Serializer,
187    };
188    use std::{ops::Deref, sync::Arc};
189
190    impl Serialize for StrVariant {
191        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
192        where
193            S: Serializer,
194        {
195            serializer.serialize_str(self.deref())
196        }
197    }
198
199    impl<'de> Deserialize<'de> for StrVariant {
200        fn deserialize<D>(deserializer: D) -> Result<StrVariant, D::Error>
201        where
202            D: Deserializer<'de>,
203        {
204            deserializer.deserialize_str(StrVariantVisitor)
205        }
206    }
207
208    struct StrVariantVisitor;
209
210    impl Visitor<'_> for StrVariantVisitor {
211        type Value = StrVariant;
212
213        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
214            formatter.write_str("a string")
215        }
216
217        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
218        where
219            E: de::Error,
220        {
221            Ok(StrVariant::Arc(Arc::from(v)))
222        }
223    }
224}
225
226#[derive(Hash, Clone, PartialEq, Eq, derive_more::Display, Serialize, Deserialize)]
227#[display("{value}")]
228#[serde(transparent)]
229pub struct Name<T> {
230    value: StrVariant,
231    #[serde(skip)]
232    phantom_data: PhantomData<fn(T) -> T>,
233}
234
235impl<T> Name<T> {
236    #[must_use]
237    pub fn new_arc(value: Arc<str>) -> Self {
238        Self {
239            value: StrVariant::Arc(value),
240            phantom_data: PhantomData,
241        }
242    }
243
244    #[must_use]
245    pub const fn new_static(value: &'static str) -> Self {
246        Self {
247            value: StrVariant::Static(value),
248            phantom_data: PhantomData,
249        }
250    }
251}
252
253impl<T> Debug for Name<T> {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        Display::fmt(&self, f)
256    }
257}
258
259impl<T> Deref for Name<T> {
260    type Target = str;
261
262    fn deref(&self) -> &Self::Target {
263        self.value.deref()
264    }
265}
266
267impl<T> Borrow<str> for Name<T> {
268    fn borrow(&self) -> &str {
269        self.deref()
270    }
271}
272
273impl<T> From<String> for Name<T> {
274    fn from(value: String) -> Self {
275        Self::new_arc(Arc::from(value))
276    }
277}
278
/// Fully qualified package name, displayed as `namespace:package_name[@version]`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PkgFqn {
    pub namespace: String,
    pub package_name: String,
    pub version: Option<String>,
}

impl Display for PkgFqn {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}:{}", self.namespace, self.package_name)?;
        match &self.version {
            Some(version) => write!(f, "@{version}"),
            None => Ok(()),
        }
    }
}
299
300impl PkgFqn {
301    #[must_use]
302    pub fn is_extension(&self) -> bool {
303        self.package_name.ends_with(SUFFIX_PKG_EXT)
304    }
305
306    #[must_use]
307    pub fn package_strip_extension_suffix(&self) -> Option<&str> {
308        self.package_name.as_str().strip_suffix(SUFFIX_PKG_EXT)
309    }
310
311    #[must_use]
312    pub fn is_namespace_obelisk(&self) -> bool {
313        self.namespace == NAMESPACE_OBELISK
314    }
315
316    #[must_use]
317    pub fn ifc_fqn_name(&self, ifc_name: &str) -> IfcFqnName {
318        IfcFqnName::from_parts(
319            &self.namespace,
320            &self.package_name,
321            ifc_name,
322            self.version.as_deref(),
323        )
324    }
325}
326
327#[derive(Hash, Clone, PartialEq, Eq)]
328pub struct IfcFqnMarker;
329
330pub type IfcFqnName = Name<IfcFqnMarker>; // namespace:name/ifc_name OR namespace:name/ifc_name@version
331
332impl IfcFqnName {
333    #[must_use]
334    pub fn namespace(&self) -> &str {
335        self.deref().split_once(':').unwrap().0
336    }
337
338    #[must_use]
339    pub fn package_name(&self) -> &str {
340        let after_colon = self.deref().split_once(':').unwrap().1;
341        after_colon.split_once('/').unwrap().0
342    }
343
344    #[must_use]
345    pub fn version(&self) -> Option<&str> {
346        self.deref().split_once('@').map(|(_, version)| version)
347    }
348
349    #[must_use]
350    pub fn pkg_fqn_name(&self) -> PkgFqn {
351        let (namespace, rest) = self.deref().split_once(':').unwrap();
352        let (package_name, rest) = rest.split_once('/').unwrap();
353        let version = rest.split_once('@').map(|(_, version)| version);
354        PkgFqn {
355            namespace: namespace.to_string(),
356            package_name: package_name.to_string(),
357            version: version.map(std::string::ToString::to_string),
358        }
359    }
360
361    #[must_use]
362    pub fn ifc_name(&self) -> &str {
363        let after_colon = self.deref().split_once(':').unwrap().1;
364        let after_slash = after_colon.split_once('/').unwrap().1;
365        after_slash
366            .split_once('@')
367            .map_or(after_slash, |(ifc, _)| ifc)
368    }
369
370    #[must_use]
371    pub fn from_parts(
372        namespace: &str,
373        package_name: &str,
374        ifc_name: &str,
375        version: Option<&str>,
376    ) -> Self {
377        let mut str = format!("{namespace}:{package_name}/{ifc_name}");
378        if let Some(version) = version {
379            str += "@";
380            str += version;
381        }
382        Self::new_arc(Arc::from(str))
383    }
384
385    #[must_use]
386    pub fn is_extension(&self) -> bool {
387        self.package_name().ends_with(SUFFIX_PKG_EXT)
388    }
389
390    #[must_use]
391    pub fn package_strip_extension_suffix(&self) -> Option<&str> {
392        self.package_name().strip_suffix(SUFFIX_PKG_EXT)
393    }
394
395    #[must_use]
396    pub fn is_namespace_obelisk(&self) -> bool {
397        self.namespace() == NAMESPACE_OBELISK
398    }
399}
400
401#[derive(Hash, Clone, PartialEq, Eq)]
402pub struct FnMarker;
403
404pub type FnName = Name<FnMarker>;
405
406#[derive(Hash, Clone, PartialEq, Eq, Serialize, Deserialize)]
407pub struct FunctionFqn {
408    pub ifc_fqn: IfcFqnName,
409    pub function_name: FnName,
410}
411
412impl FunctionFqn {
413    #[must_use]
414    pub fn new_arc(ifc_fqn: Arc<str>, function_name: Arc<str>) -> Self {
415        Self {
416            ifc_fqn: Name::new_arc(ifc_fqn),
417            function_name: Name::new_arc(function_name),
418        }
419    }
420
421    #[must_use]
422    pub const fn new_static(ifc_fqn: &'static str, function_name: &'static str) -> Self {
423        Self {
424            ifc_fqn: Name::new_static(ifc_fqn),
425            function_name: Name::new_static(function_name),
426        }
427    }
428
429    #[must_use]
430    pub const fn new_static_tuple(tuple: (&'static str, &'static str)) -> Self {
431        Self::new_static(tuple.0, tuple.1)
432    }
433
434    pub fn try_from_tuple(
435        ifc_fqn: &str,
436        function_name: &str,
437    ) -> Result<Self, FunctionFqnParseError> {
438        if ifc_fqn.contains('.') || function_name.contains('.') {
439            Err(FunctionFqnParseError::DelimiterFoundMoreThanOnce)
440        } else {
441            Ok(Self::new_arc(Arc::from(ifc_fqn), Arc::from(function_name)))
442        }
443    }
444}
445
446#[derive(Debug, thiserror::Error)]
447pub enum FunctionFqnParseError {
448    #[error("delimiter `.` not found")]
449    DelimiterNotFound,
450    #[error("delimiter `.` found more than once")]
451    DelimiterFoundMoreThanOnce,
452}
453
454impl FromStr for FunctionFqn {
455    type Err = FunctionFqnParseError;
456
457    fn from_str(s: &str) -> Result<Self, Self::Err> {
458        if let Some((ifc_fqn, function_name)) = s.split_once('.') {
459            if function_name.contains('.') {
460                Err(FunctionFqnParseError::DelimiterFoundMoreThanOnce)
461            } else {
462                Ok(Self::new_arc(Arc::from(ifc_fqn), Arc::from(function_name)))
463            }
464        } else {
465            Err(FunctionFqnParseError::DelimiterNotFound)
466        }
467    }
468}
469
470impl Display for FunctionFqn {
471    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472        write!(
473            f,
474            "{ifc_fqn}.{function_name}",
475            ifc_fqn = self.ifc_fqn,
476            function_name = self.function_name
477        )
478    }
479}
480
481impl Debug for FunctionFqn {
482    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483        Display::fmt(&self, f)
484    }
485}
486
487impl<'a> arbitrary::Arbitrary<'a> for FunctionFqn {
488    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
489        let illegal = [':', '@', '.'];
490        let namespace = u.arbitrary::<String>()?.replace(illegal, "");
491        let pkg_name = u.arbitrary::<String>()?.replace(illegal, "");
492        let ifc_name = u.arbitrary::<String>()?.replace(illegal, "");
493        let fn_name = u.arbitrary::<String>()?.replace(illegal, "");
494
495        Ok(FunctionFqn::new_arc(
496            Arc::from(format!("{namespace}:{pkg_name}/{ifc_name}")),
497            Arc::from(fn_name),
498        ))
499    }
500}
501
502#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
503pub enum SupportedFunctionReturnValue {
504    None,
505    // Top level type is result<_,_> with Err variant
506    FallibleResultErr(WastValWithType),
507    // All other top level types
508    InfallibleOrResultOk(WastValWithType),
509}
510
511#[derive(Debug, thiserror::Error)]
512pub enum ResultParsingError {
513    #[error("result cannot be parsed, multi-value results are not supported")]
514    MultiValue,
515    #[error("result cannot be parsed, {0}")]
516    TypeConversionError(#[from] val_json::type_wrapper::TypeConversionError),
517    #[error("result cannot be parsed, {0}")]
518    ValueConversionError(#[from] val_json::wast_val::WastValConversionError),
519}
520
521impl SupportedFunctionReturnValue {
522    pub fn new<
523        I: ExactSizeIterator<Item = (wasmtime::component::Val, wasmtime::component::Type)>,
524    >(
525        mut iter: I,
526    ) -> Result<Self, ResultParsingError> {
527        if iter.len() == 0 {
528            Ok(Self::None)
529        } else if iter.len() == 1 {
530            let (val, r#type) = iter.next().unwrap();
531            let r#type = TypeWrapper::try_from(r#type)?;
532            let val = WastVal::try_from(val)?;
533            match &val {
534                WastVal::Result(Err(_)) => Ok(Self::FallibleResultErr(WastValWithType {
535                    r#type,
536                    value: val,
537                })),
538                _ => Ok(Self::InfallibleOrResultOk(WastValWithType {
539                    r#type,
540                    value: val,
541                })),
542            }
543        } else {
544            Err(ResultParsingError::MultiValue)
545        }
546    }
547
548    #[cfg(feature = "test")]
549    #[must_use]
550    pub fn fallible_err(&self) -> Option<Option<&WastVal>> {
551        match self {
552            SupportedFunctionReturnValue::FallibleResultErr(WastValWithType {
553                value: WastVal::Result(Err(err)),
554                ..
555            }) => Some(err.as_deref()),
556            _ => None,
557        }
558    }
559
560    #[cfg(feature = "test")]
561    #[must_use]
562    pub fn fallible_ok(&self) -> Option<Option<&WastVal>> {
563        match self {
564            SupportedFunctionReturnValue::InfallibleOrResultOk(WastValWithType {
565                value: WastVal::Result(Ok(ok)),
566                ..
567            }) => Some(ok.as_deref()),
568            _ => None,
569        }
570    }
571
572    #[cfg(feature = "test")]
573    #[must_use]
574    pub fn val_type(&self) -> Option<&TypeWrapper> {
575        match self {
576            SupportedFunctionReturnValue::None => None,
577            SupportedFunctionReturnValue::FallibleResultErr(v)
578            | SupportedFunctionReturnValue::InfallibleOrResultOk(v) => Some(&v.r#type),
579        }
580    }
581
582    #[must_use]
583    pub fn value(&self) -> Option<&WastVal> {
584        match self {
585            SupportedFunctionReturnValue::None => None,
586            SupportedFunctionReturnValue::FallibleResultErr(v)
587            | SupportedFunctionReturnValue::InfallibleOrResultOk(v) => Some(&v.value),
588        }
589    }
590
591    #[must_use]
592    pub fn into_value(self) -> Option<WastVal> {
593        match self {
594            SupportedFunctionReturnValue::None => None,
595            SupportedFunctionReturnValue::FallibleResultErr(v)
596            | SupportedFunctionReturnValue::InfallibleOrResultOk(v) => Some(v.value),
597        }
598    }
599
600    #[must_use]
601    pub fn len(&self) -> usize {
602        match self {
603            SupportedFunctionReturnValue::None => 0,
604            _ => 1,
605        }
606    }
607
608    #[must_use]
609    pub fn is_empty(&self) -> bool {
610        matches!(self, Self::None)
611    }
612
613    #[must_use]
614    pub fn as_pending_state_finished_result(&self) -> PendingStateFinishedResultKind {
615        if let SupportedFunctionReturnValue::FallibleResultErr(_) = self {
616            PendingStateFinishedResultKind(Err(PendingStateFinishedError::FallibleError))
617        } else {
618            PendingStateFinishedResultKind(Ok(()))
619        }
620    }
621}
622
623#[derive(Debug, Clone, PartialEq, Eq)]
624pub struct Params(ParamsInternal);
625
626#[derive(Debug, Clone, PartialEq, Eq)]
627enum ParamsInternal {
628    JsonValues(Vec<Value>),
629    Vals {
630        //TODO: is Arc needed here? Or move upwards?
631        vals: Arc<[wasmtime::component::Val]>,
632    },
633    Empty,
634}
635
636impl Default for Params {
637    fn default() -> Self {
638        Self(ParamsInternal::Empty)
639    }
640}
641
642#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
643pub enum FunctionExtension {
644    Submit,
645    AwaitNext,
646    Schedule,
647}
648
649#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
650pub struct FunctionMetadata {
651    pub ffqn: FunctionFqn,
652    pub parameter_types: ParameterTypes,
653    pub return_type: Option<ReturnType>,
654    pub extension: Option<FunctionExtension>,
655    pub submittable: bool,
656}
657impl Display for FunctionMetadata {
658    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
659        write!(
660            f,
661            "{ffqn}: func{params}",
662            ffqn = self.ffqn,
663            params = self.parameter_types
664        )?;
665        if let Some(return_type) = &self.return_type {
666            write!(f, " -> {return_type}")?;
667        }
668        Ok(())
669    }
670}
671
672pub mod serde_params {
673    use crate::{Params, ParamsInternal};
674    use serde::de::{SeqAccess, Visitor};
675    use serde::ser::SerializeSeq;
676    use serde::{Deserialize, Serialize};
677    use serde_json::Value;
678    use val_json::wast_val::WastVal;
679
680    impl Serialize for Params {
681        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
682        where
683            S: ::serde::Serializer,
684        {
685            match &self.0 {
686                ParamsInternal::Vals { vals } => {
687                    let mut seq = serializer.serialize_seq(Some(vals.len()))?; // size must be equal, checked when constructed.
688                    for val in vals.iter() {
689                        let value = WastVal::try_from(val.clone())
690                            .map_err(|err| serde::ser::Error::custom(err.to_string()))?;
691                        seq.serialize_element(&value)?;
692                    }
693                    seq.end()
694                }
695                ParamsInternal::Empty => serializer.serialize_seq(Some(0))?.end(),
696                ParamsInternal::JsonValues(vec) => {
697                    let mut seq = serializer.serialize_seq(Some(vec.len()))?;
698                    for item in vec {
699                        seq.serialize_element(item)?;
700                    }
701                    seq.end()
702                }
703            }
704        }
705    }
706
707    pub struct VecVisitor;
708
709    impl<'de> Visitor<'de> for VecVisitor {
710        type Value = Vec<Value>;
711
712        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
713            formatter.write_str("a sequence of `Value`")
714        }
715
716        #[inline]
717        fn visit_seq<V>(self, mut visitor: V) -> Result<Self::Value, V::Error>
718        where
719            V: SeqAccess<'de>,
720        {
721            let mut vec = Vec::new();
722            while let Some(elem) = visitor.next_element()? {
723                vec.push(elem);
724            }
725            Ok(vec)
726        }
727    }
728
729    impl<'de> Deserialize<'de> for Params {
730        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
731        where
732            D: serde::Deserializer<'de>,
733        {
734            let vec: Vec<Value> = deserializer.deserialize_seq(VecVisitor)?;
735            if vec.is_empty() {
736                Ok(Self(ParamsInternal::Empty))
737            } else {
738                Ok(Self(ParamsInternal::JsonValues(vec)))
739            }
740        }
741    }
742}
743
744#[derive(Debug, thiserror::Error)]
745pub enum ParamsParsingError {
746    #[error("parameters cannot be parsed, cannot convert type of {idx}-th parameter")]
747    ParameterTypeError {
748        idx: usize,
749        err: TypeConversionError,
750    },
751    #[error("parameters cannot be deserialized: {0}")]
752    ParamsDeserializationError(serde_json::Error),
753    #[error("parameter cardinality mismatch, expected: {expected}, specified: {specified}")]
754    ParameterCardinalityMismatch { expected: usize, specified: usize },
755}
756
757impl ParamsParsingError {
758    #[must_use]
759    pub fn detail(&self) -> Option<String> {
760        match self {
761            ParamsParsingError::ParameterTypeError { err, .. } => Some(format!("{err:?}")),
762            ParamsParsingError::ParamsDeserializationError(err) => Some(format!("{err:?}")),
763            ParamsParsingError::ParameterCardinalityMismatch { .. } => None,
764        }
765    }
766}
767
768#[derive(Debug, thiserror::Error)]
769pub enum ParamsFromJsonError {
770    #[error("value must be a json array containing function parameters")]
771    MustBeArray,
772}
773
774impl Params {
775    #[must_use]
776    pub const fn empty() -> Self {
777        Self(ParamsInternal::Empty)
778    }
779
780    #[must_use]
781    pub fn from_wasmtime(vals: Arc<[wasmtime::component::Val]>) -> Self {
782        if vals.is_empty() {
783            Self::empty()
784        } else {
785            Self(ParamsInternal::Vals { vals })
786        }
787    }
788
789    #[must_use]
790    pub fn from_json_values(vec: Vec<Value>) -> Self {
791        if vec.is_empty() {
792            Self::empty()
793        } else {
794            Self(ParamsInternal::JsonValues(vec))
795        }
796    }
797
798    pub fn typecheck<'a>(
799        &self,
800        param_types: impl ExactSizeIterator<Item = &'a TypeWrapper>,
801    ) -> Result<(), ParamsParsingError> {
802        if param_types.len() != self.len() {
803            return Err(ParamsParsingError::ParameterCardinalityMismatch {
804                expected: param_types.len(),
805                specified: self.len(),
806            });
807        }
808        match &self.0 {
809            ParamsInternal::Vals { .. } /* already typechecked */ | ParamsInternal::Empty => {}
810            ParamsInternal::JsonValues(params) => {
811                params::deserialize_values(params, param_types)
812                .map_err(ParamsParsingError::ParamsDeserializationError)?;
813            }
814        }
815        Ok(())
816    }
817
818    pub fn as_vals(
819        &self,
820        param_types: Box<[(String, Type)]>,
821    ) -> Result<Arc<[wasmtime::component::Val]>, ParamsParsingError> {
822        if param_types.len() != self.len() {
823            return Err(ParamsParsingError::ParameterCardinalityMismatch {
824                expected: param_types.len(),
825                specified: self.len(),
826            });
827        }
828        match &self.0 {
829            ParamsInternal::JsonValues(json_vec) => {
830                let param_types = param_types
831                    .into_vec()
832                    .into_iter()
833                    .enumerate()
834                    .map(|(idx, (_param_name, ty))| {
835                        TypeWrapper::try_from(ty).map_err(|err| (idx, err))
836                    })
837                    .collect::<Result<Vec<_>, _>>()
838                    .map_err(|(idx, err)| ParamsParsingError::ParameterTypeError { idx, err })?;
839                Ok(params::deserialize_values(json_vec, param_types.iter())
840                    .map_err(ParamsParsingError::ParamsDeserializationError)?
841                    .into_iter()
842                    .map(Val::from)
843                    .collect())
844            }
845            ParamsInternal::Vals { vals, .. } => Ok(vals.clone()),
846            ParamsInternal::Empty => Ok(Arc::from([])),
847        }
848    }
849
850    #[must_use]
851    pub fn len(&self) -> usize {
852        match &self.0 {
853            ParamsInternal::JsonValues(vec) => vec.len(),
854            ParamsInternal::Vals { vals, .. } => vals.len(),
855            ParamsInternal::Empty => 0,
856        }
857    }
858
859    #[must_use]
860    pub fn is_empty(&self) -> bool {
861        self.len() == 0
862    }
863}
864
865pub mod prefixed_ulid {
866    use arbitrary::Arbitrary;
867    use derivative::Derivative;
868    use serde_with::{DeserializeFromStr, SerializeDisplay};
869    use std::{
870        fmt::{Debug, Display},
871        hash::Hasher,
872        marker::PhantomData,
873        num::ParseIntError,
874        str::FromStr,
875        sync::Arc,
876    };
877    use ulid::Ulid;
878
879    use crate::JoinSetId;
880
881    #[derive(derive_more::Display, SerializeDisplay, DeserializeFromStr, Derivative)]
882    #[display("{}_{ulid}", Self::prefix())]
883    #[derivative(Clone(bound = ""))]
884    #[derivative(Copy(bound = ""))]
885    pub struct PrefixedUlid<T: 'static> {
886        ulid: Ulid,
887        phantom_data: PhantomData<fn(T) -> T>,
888    }
889
890    impl<T> PrefixedUlid<T> {
891        const fn new(ulid: Ulid) -> Self {
892            Self {
893                ulid,
894                phantom_data: PhantomData,
895            }
896        }
897
898        fn prefix() -> &'static str {
899            std::any::type_name::<T>().rsplit("::").next().unwrap()
900        }
901    }
902
903    impl<T> PrefixedUlid<T> {
904        #[must_use]
905        pub fn generate() -> Self {
906            Self::new(Ulid::new())
907        }
908
909        #[must_use]
910        pub const fn from_parts(timestamp_ms: u64, random: u128) -> Self {
911            Self::new(Ulid::from_parts(timestamp_ms, random))
912        }
913
914        #[must_use]
915        pub fn timestamp_part(&self) -> u64 {
916            self.ulid.timestamp_ms()
917        }
918
919        #[must_use]
920        #[expect(clippy::cast_possible_truncation)]
921        pub fn random_part(&self) -> u64 {
922            self.ulid.random() as u64
923        }
924    }
925
926    #[derive(Debug, thiserror::Error)]
927    pub enum PrefixedUlidParseError {
928        #[error("wrong prefix in `{input}`, expected prefix `{expected}`")]
929        WrongPrefix { input: String, expected: String },
930        #[error("cannot parse ULID suffix from `{input}`")]
931        CannotParseUlid { input: String },
932    }
933
934    mod impls {
935        use super::{PrefixedUlid, PrefixedUlidParseError, Ulid};
936        use std::{fmt::Debug, fmt::Display, hash::Hash, marker::PhantomData, str::FromStr};
937
938        impl<T> FromStr for PrefixedUlid<T> {
939            type Err = PrefixedUlidParseError;
940
941            fn from_str(input: &str) -> Result<Self, Self::Err> {
942                let prefix = Self::prefix();
943                let mut input_chars = input.chars();
944                for exp in prefix.chars() {
945                    if input_chars.next() != Some(exp) {
946                        return Err(PrefixedUlidParseError::WrongPrefix {
947                            input: input.to_string(),
948                            expected: format!("{prefix}_"),
949                        });
950                    }
951                }
952                if input_chars.next() != Some('_') {
953                    return Err(PrefixedUlidParseError::WrongPrefix {
954                        input: input.to_string(),
955                        expected: format!("{prefix}_"),
956                    });
957                }
958                let Ok(ulid) = Ulid::from_string(input_chars.as_str()) else {
959                    return Err(PrefixedUlidParseError::CannotParseUlid {
960                        input: input.to_string(),
961                    });
962                };
963                Ok(Self {
964                    ulid,
965                    phantom_data: PhantomData,
966                })
967            }
968        }
969
970        impl<T> Debug for PrefixedUlid<T> {
971            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
972                Display::fmt(&self, f)
973            }
974        }
975
976        impl<T> Hash for PrefixedUlid<T> {
977            fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
978                Self::prefix().hash(state);
979                self.ulid.hash(state);
980                self.phantom_data.hash(state);
981            }
982        }
983
984        impl<T> PartialEq for PrefixedUlid<T> {
985            fn eq(&self, other: &Self) -> bool {
986                self.ulid == other.ulid
987            }
988        }
989
990        impl<T> Eq for PrefixedUlid<T> {}
991
992        impl<T> PartialOrd for PrefixedUlid<T> {
993            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
994                Some(self.cmp(other))
995            }
996        }
997
998        impl<T> Ord for PrefixedUlid<T> {
999            fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1000                self.ulid.cmp(&other.ulid)
1001            }
1002        }
1003    }
1004
    /// Zero-sized marker types that select the prefix of a `PrefixedUlid`.
    pub mod prefix {
        /// Marker for top-level execution ids (`ExecutionIdTopLevel`).
        pub struct E;
        /// Marker for executor ids (`ExecutorId`).
        pub struct Exr;
        /// Marker for run ids (`RunId`).
        pub struct Run;
        /// Marker for delay ids (`DelayId`).
        pub struct Delay;
    }

    /// Prefixed ULID using the `Exr` marker.
    pub type ExecutorId = PrefixedUlid<prefix::Exr>;
    /// Prefixed ULID of a top-level (non-derived) execution, using the `E` marker.
    pub type ExecutionIdTopLevel = PrefixedUlid<prefix::E>;
    /// Prefixed ULID using the `Run` marker.
    pub type RunId = PrefixedUlid<prefix::Run>;
    /// Prefixed ULID using the `Delay` marker.
    pub type DelayId = PrefixedUlid<prefix::Delay>;
1016
1017    impl<'a, T> Arbitrary<'a> for PrefixedUlid<T> {
1018        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
1019            Ok(Self::new(ulid::Ulid::from_parts(
1020                u.arbitrary()?,
1021                u.arbitrary()?,
1022            )))
1023        }
1024    }
1025
    /// Identifier of an execution: either a top-level id, or an id derived from a parent
    /// execution via a join set (see `ExecutionId::next_level`).
    ///
    /// Serialized through its `Display`/`FromStr` string form. The derived ordering compares
    /// the variant first (`TopLevel` before `Derived`), so do not reorder the variants.
    #[derive(Hash, PartialEq, Eq, PartialOrd, Ord, SerializeDisplay, DeserializeFromStr, Clone)]
    pub enum ExecutionId {
        TopLevel(ExecutionIdTopLevel),
        Derived(ExecutionIdDerived),
    }
1031
    /// Execution id of a child execution, rendered as `<top_level>.<infix>.<idx>`.
    ///
    /// `infix` accumulates the join set path (and the parents' indices) leading from the
    /// top-level execution; `idx` distinguishes siblings that share the same path.
    /// Field order determines the derived `Ord` and `Hash`; do not reorder the fields.
    #[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, SerializeDisplay, DeserializeFromStr)]
    pub struct ExecutionIdDerived {
        top_level: ExecutionIdTopLevel,
        infix: Arc<str>,
        idx: u64,
    }
1038    impl ExecutionIdDerived {
1039        #[must_use]
1040        pub fn get_incremented(&self) -> Self {
1041            self.get_incremented_by(1)
1042        }
1043        #[must_use]
1044        pub fn get_incremented_by(&self, count: u64) -> Self {
1045            ExecutionIdDerived {
1046                top_level: self.top_level,
1047                infix: self.infix.clone(),
1048                idx: self.idx + count,
1049            }
1050        }
1051        #[must_use]
1052        pub fn next_level(&self, join_set_id: &JoinSetId) -> ExecutionIdDerived {
1053            let ExecutionIdDerived {
1054                top_level,
1055                infix,
1056                idx,
1057            } = self;
1058            let infix = Arc::from(format!(
1059                "{infix}{EXECUTION_ID_INFIX}{idx}{EXECUTION_ID_INFIX}{join_set_id}"
1060            ));
1061            ExecutionIdDerived {
1062                top_level: *top_level,
1063                infix,
1064                idx: 0,
1065            }
1066        }
1067        fn display_or_debug(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1068            let ExecutionIdDerived {
1069                top_level,
1070                infix,
1071                idx,
1072            } = self;
1073            write!(
1074                f,
1075                "{top_level}{EXECUTION_ID_INFIX}{infix}{EXECUTION_ID_INFIX}{idx}"
1076            )
1077        }
1078    }
1079    impl Debug for ExecutionIdDerived {
1080        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1081            self.display_or_debug(f)
1082        }
1083    }
1084    impl Display for ExecutionIdDerived {
1085        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1086            self.display_or_debug(f)
1087        }
1088    }
1089    impl FromStr for ExecutionIdDerived {
1090        type Err = ExecutionIdDerivedParseError;
1091
1092        fn from_str(input: &str) -> Result<Self, Self::Err> {
1093            if let Some((prefix, suffix)) = input.split_once(EXECUTION_ID_INFIX) {
1094                let top_level = PrefixedUlid::from_str(prefix)
1095                    .map_err(ExecutionIdDerivedParseError::PrefixedUlidParseError)?;
1096                let Some((infix, idx)) = suffix.rsplit_once(EXECUTION_ID_INFIX) else {
1097                    return Err(ExecutionIdDerivedParseError::SecondDelimiterNotFound);
1098                };
1099                let infix = Arc::from(infix);
1100                let idx =
1101                    u64::from_str(idx).map_err(ExecutionIdDerivedParseError::ParseIndexError)?;
1102                Ok(ExecutionIdDerived {
1103                    top_level,
1104                    infix,
1105                    idx,
1106                })
1107            } else {
1108                Err(ExecutionIdDerivedParseError::FirstDelimiterNotFound)
1109            }
1110        }
1111    }
1112    impl<'a> Arbitrary<'a> for ExecutionIdDerived {
1113        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
1114            let top_level = ExecutionId::TopLevel(ExecutionIdTopLevel::arbitrary(u)?);
1115            let join_set_id = JoinSetId::arbitrary(u)?;
1116            Ok(top_level.next_level(&join_set_id))
1117        }
1118    }
    /// Error returned by `ExecutionIdDerived::from_str`.
    #[derive(Debug, thiserror::Error)]
    pub enum ExecutionIdDerivedParseError {
        /// The part before the first `.` is not a valid top-level id.
        #[error(transparent)]
        PrefixedUlidParseError(PrefixedUlidParseError),
        /// The input contains no `.` at all.
        #[error(
            "cannot parse derived execution id - first delimiter `{EXECUTION_ID_INFIX}` not found"
        )]
        FirstDelimiterNotFound,
        /// Only one `.` was found, so the index part is missing.
        #[error(
            "cannot parse derived execution id - second delimiter `{EXECUTION_ID_INFIX}` not found"
        )]
        SecondDelimiterNotFound,
        /// The part after the last `.` is not a valid `u64`.
        #[error("cannot parse derived execution id - last suffix must be a number")]
        ParseIndexError(ParseIntError),
    }
1134
1135    impl ExecutionId {
1136        #[must_use]
1137        pub fn generate() -> Self {
1138            ExecutionId::TopLevel(PrefixedUlid::generate())
1139        }
1140
1141        #[must_use]
1142        pub fn get_top_level(&self) -> ExecutionIdTopLevel {
1143            match &self {
1144                ExecutionId::TopLevel(prefixed_ulid) => *prefixed_ulid,
1145                ExecutionId::Derived(ExecutionIdDerived { top_level, .. }) => *top_level,
1146            }
1147        }
1148
1149        #[must_use]
1150        pub fn timestamp_part(&self) -> u64 {
1151            self.get_top_level().timestamp_part()
1152        }
1153
1154        #[must_use]
1155        pub fn random_seed(&self) -> u64 {
1156            let mut hasher = fxhash::FxHasher::default();
1157            hasher.write_u64(self.get_top_level().random_part());
1158            hasher.write_u64(self.timestamp_part());
1159            if let ExecutionId::Derived(ExecutionIdDerived {
1160                top_level: _,
1161                infix,
1162                idx,
1163            }) = self
1164            {
1165                hasher.write(infix.as_bytes());
1166                hasher.write_u64(*idx);
1167            }
1168            hasher.finish()
1169        }
1170
1171        #[must_use]
1172        pub const fn from_parts(timestamp_ms: u64, random_part: u128) -> Self {
1173            ExecutionId::TopLevel(ExecutionIdTopLevel::from_parts(timestamp_ms, random_part))
1174        }
1175
1176        #[must_use]
1177        pub fn next_level(&self, join_set_id: &JoinSetId) -> ExecutionIdDerived {
1178            match &self {
1179                ExecutionId::TopLevel(top_level) => ExecutionIdDerived {
1180                    top_level: *top_level,
1181                    infix: Arc::from(join_set_id.to_string()),
1182                    idx: 0,
1183                },
1184                ExecutionId::Derived(derived) => derived.next_level(join_set_id),
1185            }
1186        }
1187
1188        fn display_or_debug(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1189            match &self {
1190                ExecutionId::TopLevel(top_level) => Display::fmt(top_level, f),
1191                ExecutionId::Derived(derived) => Display::fmt(derived, f),
1192            }
1193        }
1194    }
1195
    /// Separator between the top-level id, the infix and the index of a derived execution id;
    /// also used between the segments accumulated inside the infix.
    const EXECUTION_ID_INFIX: char = '.';

    /// Error returned by `ExecutionId::from_str`.
    #[derive(Debug, thiserror::Error)]
    pub enum ExecutionIdParseError {
        /// The (top-level part of the) id is not a valid prefixed ULID.
        #[error(transparent)]
        PrefixedUlidParseError(#[from] PrefixedUlidParseError),
        /// Not produced by `ExecutionId::from_str`, which only attempts derived parsing when
        /// the delimiter is present.
        #[error(
            "cannot parse derived execution id - first delimiter `{EXECUTION_ID_INFIX}` not found"
        )]
        FirstDelimiterNotFound,
        /// Only one `.` was found, so the index part is missing.
        #[error(
            "cannot parse derived execution id - second delimiter `{EXECUTION_ID_INFIX}` not found"
        )]
        SecondDelimiterNotFound,
        /// The part after the last `.` is not a valid `u64`.
        #[error("cannot parse derived execution id - last suffix must be a number")]
        ParseIndexError(#[from] ParseIntError),
    }
1213
1214    impl FromStr for ExecutionId {
1215        type Err = ExecutionIdParseError;
1216
1217        fn from_str(input: &str) -> Result<Self, Self::Err> {
1218            if input.contains(EXECUTION_ID_INFIX) {
1219                ExecutionIdDerived::from_str(input)
1220                    .map(ExecutionId::Derived)
1221                    .map_err(|err| match err {
1222                        ExecutionIdDerivedParseError::FirstDelimiterNotFound => {
1223                            unreachable!("first delimiter checked")
1224                        }
1225                        ExecutionIdDerivedParseError::SecondDelimiterNotFound => {
1226                            ExecutionIdParseError::SecondDelimiterNotFound
1227                        }
1228                        ExecutionIdDerivedParseError::PrefixedUlidParseError(err) => {
1229                            ExecutionIdParseError::PrefixedUlidParseError(err)
1230                        }
1231                        ExecutionIdDerivedParseError::ParseIndexError(err) => {
1232                            ExecutionIdParseError::ParseIndexError(err)
1233                        }
1234                    })
1235            } else {
1236                Ok(ExecutionId::TopLevel(PrefixedUlid::from_str(input)?))
1237            }
1238        }
1239    }
1240
1241    impl Debug for ExecutionId {
1242        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1243            self.display_or_debug(f)
1244        }
1245    }
1246
1247    impl Display for ExecutionId {
1248        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1249            self.display_or_debug(f)
1250        }
1251    }
1252
1253    impl<'a> Arbitrary<'a> for ExecutionId {
1254        fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
1255            Ok(ExecutionId::TopLevel(PrefixedUlid::arbitrary(u)?))
1256        }
1257    }
1258}
1259
/// Identifier of a join set, rendered as `<kind>:<name>` (e.g. `n:name`).
///
/// Construct it with `JoinSetId::new`, which validates the name.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    derive_more::Display,
    serde_with::SerializeDisplay,
    serde_with::DeserializeFromStr,
)]
#[non_exhaustive] // force using the constructor as much as possible due to validation
#[display("{kind}{JOIN_SET_ID_INFIX}{name}")]
pub struct JoinSetId {
    pub kind: JoinSetKind,
    pub name: StrVariant,
}
1276
1277impl JoinSetId {
1278    pub fn new(kind: JoinSetKind, name: StrVariant) -> Result<Self, InvalidNameError<JoinSetId>> {
1279        Ok(Self {
1280            kind,
1281            name: check_name(name, CHARSET_EXTRA_JSON_SET)?,
1282        })
1283    }
1284}
/// Characters from which arbitrary join set names are drawn: ASCII alphanumerics plus
/// `CHARSET_EXTRA_JSON_SET`.
const CHARSET_JOIN_SET_NAME: &str =
    const_format::concatcp!(CHARSET_ALPHANUMERIC, CHARSET_EXTRA_JSON_SET);

/// ASCII letters (both cases) and digits.
pub const CHARSET_ALPHANUMERIC: &str =
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
1290pub fn random_string(
1291    rng: &mut rand::rngs::StdRng,
1292    min_length: u16,
1293    max_length_exclusive: u16,
1294    charset: &'static str,
1295) -> String {
1296    let length_inclusive = rand::Rng::gen_range(rng, min_length..max_length_exclusive);
1297    (0..=length_inclusive)
1298        .map(|_| {
1299            let idx = rand::Rng::gen_range(rng, 0..charset.len());
1300            charset.chars().nth(idx).expect("idx is < charset.len()")
1301        })
1302        .collect()
1303}
1304
/// Kind of a join set.
///
/// `Display` prints a one-letter code (`o`, `n`, `g`), which forms the first part of a
/// serialized `JoinSetId`. Note that the serde derives use the variant names, not these codes.
/// Variant order feeds the `EnumIter` and `Arbitrary` derives.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    derive_more::Display,
    Serialize,
    Deserialize,
    strum::EnumIter,
    Arbitrary,
)]
#[display("{}", self.as_code())]
pub enum JoinSetKind {
    OneOff,
    Named,
    Generated,
}
1324impl JoinSetKind {
1325    fn as_code(&self) -> &'static str {
1326        match self {
1327            JoinSetKind::OneOff => "o",
1328            JoinSetKind::Named => "n",
1329            JoinSetKind::Generated => "g",
1330        }
1331    }
1332}
1333impl FromStr for JoinSetKind {
1334    type Err = &'static str;
1335    fn from_str(s: &str) -> Result<Self, Self::Err> {
1336        use strum::IntoEnumIterator;
1337        Self::iter()
1338            .find(|variant| s == variant.as_code())
1339            .ok_or("unknown join set kind")
1340    }
1341}
1342
/// Separates the kind code from the name in a serialized `JoinSetId`.
const JOIN_SET_ID_INFIX: char = ':';
/// Characters allowed in join set names in addition to ASCII alphanumerics.
/// NOTE(review): the name reads like a typo of `CHARSET_EXTRA_JOIN_SET`; renaming it would
/// require updating every use.
const CHARSET_EXTRA_JSON_SET: &str = "_-/";
1345
1346impl FromStr for JoinSetId {
1347    type Err = JoinSetIdParseError;
1348
1349    fn from_str(input: &str) -> Result<Self, Self::Err> {
1350        let Some((kind, name)) = input.split_once(JOIN_SET_ID_INFIX) else {
1351            return Err(JoinSetIdParseError::WrongParts);
1352        };
1353        let kind = kind
1354            .parse()
1355            .map_err(JoinSetIdParseError::JoinSetKindParseError)?;
1356        Ok(JoinSetId::new(kind, StrVariant::from(name.to_string()))?)
1357    }
1358}
1359
1360#[derive(Debug, thiserror::Error)]
1361pub enum JoinSetIdParseError {
1362    #[error("join set must consist of three parts separated by {JOIN_SET_ID_INFIX} ")]
1363    WrongParts,
1364    #[error("cannot parse join set id's execution id - {0}")]
1365    ExecutionIdParseError(#[from] ExecutionIdParseError),
1366    #[error("cannot parse join set kind - {0}")]
1367    JoinSetKindParseError(&'static str),
1368    #[error("cannot parse join set id - {0}")]
1369    InvalidName(#[from] InvalidNameError<JoinSetId>),
1370}
1371
1372impl<'a> Arbitrary<'a> for JoinSetId {
1373    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
1374        let name: String = {
1375            let length_inclusive = u.int_in_range(0..=10).unwrap();
1376            (0..=length_inclusive)
1377                .map(|_| {
1378                    let idx = u.choose_index(CHARSET_JOIN_SET_NAME.len()).unwrap();
1379                    CHARSET_JOIN_SET_NAME
1380                        .chars()
1381                        .nth(idx)
1382                        .expect("idx is < charset.len()")
1383                })
1384                .collect()
1385        };
1386
1387        Ok(JoinSetId::new(JoinSetKind::Generated, StrVariant::from(name)).unwrap())
1388    }
1389}
1390
/// Type of a component.
///
/// Its string form, which is also used for serde, is the snake_case variant name
/// (`activity_wasm`, `workflow`, `webhook_endpoint`).
#[derive(
    Debug,
    Clone,
    Copy,
    strum::Display,
    PartialEq,
    Eq,
    strum::EnumString,
    Hash,
    serde_with::SerializeDisplay,
    serde_with::DeserializeFromStr,
)]
#[strum(serialize_all = "snake_case")]
pub enum ComponentType {
    ActivityWasm,
    Workflow,
    WebhookEndpoint,
}
1409
/// Identifier of a component, rendered as `<component_type>:<name>:<hash>`. `Debug` prints
/// the same string.
///
/// Prefer `ComponentId::new`, which validates `name`.
#[derive(
    derive_more::Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    serde_with::SerializeDisplay,
    serde_with::DeserializeFromStr,
    derive_more::Display,
)]
#[display("{component_type}:{name}:{hash}")]
#[debug("{}", self)]
#[non_exhaustive] // force using the constructor as much as possible due to validation
pub struct ComponentId {
    pub component_type: ComponentType,
    pub name: StrVariant,
    pub hash: StrVariant,
}
1428impl ComponentId {
1429    pub fn new(
1430        component_type: ComponentType,
1431        name: StrVariant,
1432        hash: StrVariant,
1433    ) -> Result<Self, InvalidNameError<Self>> {
1434        Ok(Self {
1435            component_type,
1436            name: check_name(name, "_")?,
1437            hash,
1438        })
1439    }
1440
1441    #[must_use]
1442    pub const fn dummy_activity() -> Self {
1443        Self {
1444            component_type: ComponentType::ActivityWasm,
1445            name: StrVariant::empty(),
1446            hash: StrVariant::empty(),
1447        }
1448    }
1449
1450    #[must_use]
1451    pub const fn dummy_workflow() -> ComponentId {
1452        ComponentId {
1453            component_type: ComponentType::Workflow,
1454            name: StrVariant::empty(),
1455            hash: StrVariant::empty(),
1456        }
1457    }
1458}
1459
1460pub fn check_name<T>(
1461    name: StrVariant,
1462    special: &'static str,
1463) -> Result<StrVariant, InvalidNameError<T>> {
1464    if let Some(invalid) = name
1465        .as_ref()
1466        .chars()
1467        .find(|c| !c.is_ascii_alphanumeric() && !special.contains(*c))
1468    {
1469        Err(InvalidNameError::<T> {
1470            invalid,
1471            name: name.as_ref().to_string(),
1472            special,
1473            phantom_data: PhantomData,
1474        })
1475    } else {
1476        Ok(name)
1477    }
1478}
/// Error returned by `check_name` when a name contains a disallowed character.
///
/// `T` only selects the type name shown in the message, namely the last path segment of
/// `std::any::type_name::<T>()`.
#[derive(Debug, thiserror::Error)]
#[error(
    "name of {} `{name}` contains invalid character `{invalid}`, must only contain alphanumeric characters and following characters {special}",
    std::any::type_name::<T>().rsplit("::").next().unwrap()
)]
pub struct InvalidNameError<T> {
    /// First character that failed validation.
    invalid: char,
    /// The full rejected name.
    name: String,
    /// Non-alphanumeric characters that were allowed.
    special: &'static str,
    phantom_data: PhantomData<T>,
}
1490
/// Error returned by `ComponentId::from_str`.
///
/// NOTE(review): the messages refer to `ComponentConfigHash`, which looks like a stale type
/// name; confirm before changing them, since they are user-visible.
#[derive(Debug, thiserror::Error)]
pub enum ConfigIdParseError {
    /// Fewer than two `:` separators were found.
    #[error("cannot parse ComponentConfigHash - delimiter ':' not found")]
    DelimiterNotFound,
    /// The leading part is not a known `ComponentType`.
    #[error("cannot parse prefix of ComponentConfigHash - {0}")]
    ComponentTypeParseError(#[from] strum::ParseError),
    /// NOTE(review): not produced by `ComponentId::from_str`, which stores the hash unparsed.
    #[error("cannot parse suffix of ComponentConfigHash - {0}")]
    ContentDigestParseErrror(#[from] DigestParseErrror),
}
1500
1501impl FromStr for ComponentId {
1502    type Err = ConfigIdParseError;
1503
1504    fn from_str(input: &str) -> Result<Self, Self::Err> {
1505        let (component_type, input) = input.split_once(':').ok_or(Self::Err::DelimiterNotFound)?;
1506        let (name, hash) = input.split_once(':').ok_or(Self::Err::DelimiterNotFound)?;
1507        let component_type = component_type.parse()?;
1508        Ok(Self {
1509            component_type,
1510            name: StrVariant::from(name.to_string()),
1511            hash: StrVariant::from(hash.to_string()),
1512        })
1513    }
1514}
1515
/// Supported hash algorithms. The string form is the snake_case variant name (`sha256`).
#[derive(
    Debug,
    Clone,
    Copy,
    strum::Display,
    strum::EnumString,
    PartialEq,
    Eq,
    Hash,
    serde_with::SerializeDisplay,
    serde_with::DeserializeFromStr,
)]
#[strum(serialize_all = "snake_case")]
pub enum HashType {
    Sha256,
}
1532
/// Content digest, a newtype over `Digest` that shares its `<hash_type>:<hex>` string form
/// and dereferences to it.
#[derive(
    Debug,
    Clone,
    derive_more::Display,
    derive_more::FromStr,
    derive_more::Deref,
    PartialEq,
    Eq,
    Hash,
    serde_with::SerializeDisplay,
    serde_with::DeserializeFromStr,
)]
pub struct ContentDigest(Digest);
1546
1547impl ContentDigest {
1548    #[must_use]
1549    pub fn new(hash_type: HashType, hash_base16: String) -> Self {
1550        Self(Digest::new(hash_type, hash_base16))
1551    }
1552}
1553
/// Hash value together with its algorithm, rendered as `<hash_type>:<hash_base16>`.
#[derive(
    Debug,
    Clone,
    derive_more::Display,
    PartialEq,
    Eq,
    Hash,
    serde_with::SerializeDisplay,
    serde_with::DeserializeFromStr,
)]
#[display("{hash_type}:{hash_base16}")]
pub struct Digest {
    hash_type: HashType,
    /// Hex-encoded hash. `FromStr` requires exactly 64 hex digits; `Digest::new` does not check.
    hash_base16: StrVariant,
}
1569impl Digest {
1570    #[must_use]
1571    pub fn new(hash_type: HashType, hash_base16: String) -> Self {
1572        Self {
1573            hash_type,
1574            hash_base16: StrVariant::Arc(Arc::from(hash_base16)),
1575        }
1576    }
1577
1578    #[must_use]
1579    pub fn hash_type(&self) -> HashType {
1580        self.hash_type
1581    }
1582
1583    #[must_use]
1584    pub fn digest_base16(&self) -> &str {
1585        &self.hash_base16
1586    }
1587}
1588
/// Error returned by `Digest::from_str`.
/// NOTE: the type name carries a typo (`Errror`); it is public, so it is kept as is.
#[derive(Debug, thiserror::Error)]
pub enum DigestParseErrror {
    /// No `:` between hash type and hex value.
    #[error("cannot parse ContentDigest - delimiter ':' not found")]
    DelimiterNotFound,
    /// The part before `:` is not a known `HashType`.
    #[error("cannot parse ContentDigest - invalid prefix `{hash_type}`")]
    TypeParseError { hash_type: String },
    /// The hex part does not have length 64 (the value is its length in bytes).
    #[error("cannot parse ContentDigest - invalid suffix length, expected 64 hex digits, got {0}")]
    SuffixLength(usize),
    /// The hex part contains a non-hex character.
    #[error(
        "cannot parse ContentDigest - suffix must be hex-encoded, got invalid character `{0}`"
    )]
    SuffixInvalid(char),
}
1602
1603impl FromStr for Digest {
1604    type Err = DigestParseErrror;
1605
1606    fn from_str(input: &str) -> Result<Self, Self::Err> {
1607        let (hash_type, hash_base16) = input.split_once(':').ok_or(Self::Err::DelimiterNotFound)?;
1608        let hash_type =
1609            HashType::from_str(hash_type).map_err(|_err| Self::Err::TypeParseError {
1610                hash_type: hash_type.to_string(),
1611            })?;
1612        if hash_base16.len() != 64 {
1613            return Err(Self::Err::SuffixLength(hash_base16.len()));
1614        }
1615        if let Some(invalid) = hash_base16.chars().find(|c| !c.is_ascii_hexdigit()) {
1616            return Err(Self::Err::SuffixInvalid(invalid));
1617        }
1618        Ok(Self {
1619            hash_type,
1620            hash_base16: StrVariant::Arc(Arc::from(hash_base16)),
1621        })
1622    }
1623}
1624
/// Return type of a function: its type description plus the WIT rendering.
///
/// Equality compares only `type_wrapper` (`wit_type` is ignored); `Display` prints `wit_type`.
#[derive(
    Debug, Clone, serde::Serialize, serde::Deserialize, Derivative, Eq, derive_more::Display,
)]
#[derivative(PartialEq)]
#[display("{wit_type}")]
pub struct ReturnType {
    pub type_wrapper: TypeWrapper,
    #[derivative(PartialEq = "ignore")]
    pub wit_type: StrVariant,
}
1635
/// A function parameter: its type description, name, and WIT rendering.
///
/// Equality compares only `type_wrapper` (`name` and `wit_type` are ignored); `Display`
/// prints `<name>: <wit_type>`.
#[derive(
    Debug, Clone, serde::Serialize, serde::Deserialize, Derivative, Eq, derive_more::Display,
)]
#[derivative(PartialEq)]
#[display("{name}: {wit_type}")]
pub struct ParameterType {
    pub type_wrapper: TypeWrapper,
    #[derivative(PartialEq = "ignore")]
    pub name: StrVariant,
    #[derivative(PartialEq = "ignore")]
    pub wit_type: StrVariant,
}
1648
/// Ordered parameter list of a function; dereferences to the inner `Vec`.
#[derive(
    Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq, Default, derive_more::Deref,
)]
pub struct ParameterTypes(pub Vec<ParameterType>);
1653
1654impl Debug for ParameterTypes {
1655    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1656        write!(f, "(")?;
1657        let mut iter = self.0.iter().peekable();
1658        while let Some(p) = iter.next() {
1659            write!(f, "{p:?}")?;
1660            if iter.peek().is_some() {
1661                write!(f, ", ")?;
1662            }
1663        }
1664        write!(f, ")")
1665    }
1666}
1667
1668impl Display for ParameterTypes {
1669    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1670        write!(f, "(")?;
1671        let mut iter = self.0.iter().peekable();
1672        while let Some(p) = iter.next() {
1673            write!(f, "{p}")?;
1674            if iter.peek().is_some() {
1675                write!(f, ", ")?;
1676            }
1677        }
1678        write!(f, ")")
1679    }
1680}
1681
/// Functions exported by a single interface, keyed by function name (insertion-ordered).
#[derive(Debug, Clone)]
pub struct PackageIfcFns {
    pub ifc_fqn: IfcFqnName,
    /// NOTE(review): presumably marks an `-obelisk-ext` extension interface — confirm.
    pub extension: bool,
    pub fns: IndexMap<FnName, FunctionMetadata>,
}
1688
/// Retry settings of a component.
#[derive(Debug, Clone, Copy)]
pub struct ComponentRetryConfig {
    pub max_retries: u32,
    /// NOTE(review): presumably the base delay of an exponential backoff — confirm against the
    /// retry logic.
    pub retry_exp_backoff: Duration,
}
1694
/// Lookup of exported functions and the components that export them.
///
/// Implementation must not return `-obelisk-ext` suffix in any package name, nor `obelisk` namespace.
#[async_trait]
pub trait FunctionRegistry: Send + Sync {
    /// Returns the metadata, owning component id and retry config of `ffqn`.
    /// NOTE(review): `None` presumably means no component exports `ffqn` — confirm in impls.
    async fn get_by_exported_function(
        &self,
        ffqn: &FunctionFqn,
    ) -> Option<(FunctionMetadata, ComponentId, ComponentRetryConfig)>;

    /// All exported interfaces together with their functions.
    fn all_exports(&self) -> &[PackageIfcFns];
}
1705
/// String key-value metadata attached to an execution; in this file it carries propagated
/// tracing context under `tracing:`-prefixed keys.
///
/// `None` is used by `ExecutionMetadata::empty` because a map cannot be built in `const`
/// context. The derived equality distinguishes `None` from an empty map. `Display` prints the
/// `Debug` form of the inner option.
#[derive(Debug, Default, Clone, Serialize, Deserialize, derive_more::Display, PartialEq, Eq)]
#[display("{_0:?}")]
pub struct ExecutionMetadata(Option<hashbrown::HashMap<String, String>>);
1709
impl ExecutionMetadata {
    /// Marker key (stored as `tracing:obelisk-tracing-linked`) whose presence tells `enrich`
    /// to add the propagated span as a link instead of as the parent.
    const LINKED_KEY: &str = "obelisk-tracing-linked";
    /// Metadata without entries, usable in `const` contexts.
    #[must_use]
    pub const fn empty() -> Self {
        // The `Option` exists only so this can be `const`; remove it when const hashmap creation is allowed - https://github.com/rust-lang/rust/issues/123197
        Self(None)
    }

    /// Captures the current tracing context so that `enrich` later sets it as the parent.
    /// `less_specific` is the fallback span used when the current span injects nothing.
    #[must_use]
    pub fn from_parent_span(less_specific: &Span) -> Self {
        ExecutionMetadata::create(less_specific, false)
    }

    /// Like `from_parent_span`, but also writes the linked marker so that `enrich` adds a
    /// span link instead of setting the parent.
    #[must_use]
    pub fn from_linked_span(less_specific: &Span) -> Self {
        ExecutionMetadata::create(less_specific, true)
    }

    /// Attempt to use `Span::current()` to fill the trace and parent span.
    /// If that fails, which can happen due to interference with e.g.
    /// the stdout layer of the subscriber, fall back to `span`
    /// (NOTE(review): presumably created by callers at info level — confirm at call sites).
    /// When `link_marker` is set, `LINKED_KEY` is written as well.
    #[must_use]
    #[expect(clippy::items_after_statements)]
    fn create(span: &Span, link_marker: bool) -> Self {
        use tracing_opentelemetry::OpenTelemetrySpanExt as _;
        let mut metadata = Self(Some(hashbrown::HashMap::default()));
        let mut metadata_view = ExecutionMetadataInjectorView {
            metadata: &mut metadata,
        };
        // Inject the OpenTelemetry context of `s` into the metadata using the globally
        // configured text-map propagator; entries land under `tracing:`-prefixed keys.
        fn inject(s: &Span, metadata_view: &mut ExecutionMetadataInjectorView) {
            opentelemetry::global::get_text_map_propagator(|propagator| {
                propagator.inject_context(&s.context(), metadata_view);
            });
        }
        inject(&Span::current(), &mut metadata_view);
        if metadata_view.is_empty() {
            // The subscriber sent us a current span that is actually disabled, so nothing
            // was injected; retry with the explicitly passed span.
            inject(span, &mut metadata_view);
        }
        if link_marker {
            // Only the key's presence matters to `enrich`; the value is left empty.
            metadata_view.set(Self::LINKED_KEY, String::new());
        }
        metadata
    }

    /// Applies the propagated context to `span`: adds it as a link when the linked marker is
    /// present, otherwise sets it as `span`'s parent.
    pub fn enrich(&self, span: &Span) {
        use opentelemetry::trace::TraceContextExt as _;
        use tracing_opentelemetry::OpenTelemetrySpanExt as _;

        let metadata_view = ExecutionMetadataExtractorView { metadata: self };
        let otel_context = opentelemetry::global::get_text_map_propagator(|propagator| {
            propagator.extract(&metadata_view)
        });
        if metadata_view.get(Self::LINKED_KEY).is_some() {
            let linked_span_context = otel_context.span().span_context().clone();
            span.add_link(linked_span_context);
        } else {
            span.set_parent(otel_context);
        }
    }
}
1773
/// Mutable view over an `ExecutionMetadata`, used as an OpenTelemetry `Injector`.
struct ExecutionMetadataInjectorView<'a> {
    metadata: &'a mut ExecutionMetadata,
}
1777
1778impl ExecutionMetadataInjectorView<'_> {
1779    fn is_empty(&self) -> bool {
1780        self.metadata
1781            .0
1782            .as_ref()
1783            .is_some_and(hashbrown::HashMap::is_empty)
1784    }
1785}
1786
1787impl opentelemetry::propagation::Injector for ExecutionMetadataInjectorView<'_> {
1788    fn set(&mut self, key: &str, value: String) {
1789        let key = format!("tracing:{key}");
1790        let map = if let Some(map) = self.metadata.0.as_mut() {
1791            map
1792        } else {
1793            self.metadata.0 = Some(hashbrown::HashMap::new());
1794            assert_matches!(&mut self.metadata.0, Some(map) => map)
1795        };
1796        map.insert(key, value);
1797    }
1798}
1799
/// Read-only view over an `ExecutionMetadata`, used as an OpenTelemetry `Extractor`.
struct ExecutionMetadataExtractorView<'a> {
    metadata: &'a ExecutionMetadata,
}
1803
1804impl opentelemetry::propagation::Extractor for ExecutionMetadataExtractorView<'_> {
1805    fn get(&self, key: &str) -> Option<&str> {
1806        self.metadata
1807            .0
1808            .as_ref()
1809            .and_then(|map| map.get(&format!("tracing:{key}")))
1810            .map(std::string::String::as_str)
1811    }
1812
1813    fn keys(&self) -> Vec<&str> {
1814        match &self.metadata.0.as_ref() {
1815            Some(map) => map
1816                .keys()
1817                .filter_map(|key| key.strip_prefix("tracing:"))
1818                .collect(),
1819            None => vec![],
1820        }
1821    }
1822}
1823
1824#[cfg(test)]
1825mod tests {
1826
1827    use crate::{prefixed_ulid::ExecutorId, ExecutionId, JoinSetId, JoinSetKind, StrVariant};
1828    use std::{
1829        hash::{DefaultHasher, Hash, Hasher},
1830        str::FromStr,
1831        sync::Arc,
1832    };
1833
1834    #[cfg(madsim)]
1835    #[test]
1836    fn ulid_generation_should_be_deterministic() {
1837        let mut builder_a = madsim::runtime::Builder::from_env();
1838        builder_a.check = true;
1839
1840        let mut builder_b = madsim::runtime::Builder::from_env(); // Builder: Clone would be useful
1841        builder_b.check = true;
1842        builder_b.seed = builder_a.seed;
1843
1844        assert_eq!(
1845            builder_a.run(|| async { ulid::Ulid::new() }),
1846            builder_b.run(|| async { ulid::Ulid::new() })
1847        );
1848    }
1849
1850    #[test]
1851    fn ulid_parsing() {
1852        let generated = ExecutorId::generate();
1853        let str = generated.to_string();
1854        let parsed = str.parse().unwrap();
1855        assert_eq!(generated, parsed);
1856    }
1857
1858    #[test]
1859    fn execution_id_parsing_top_level() {
1860        let generated = ExecutionId::generate();
1861        let str = generated.to_string();
1862        let parsed = str.parse().unwrap();
1863        assert_eq!(generated, parsed);
1864    }
1865
1866    #[test]
1867    fn execution_id_with_one_level_should_parse() {
1868        let top_level = ExecutionId::generate();
1869        let join_set_id = JoinSetId::new(JoinSetKind::Named, StrVariant::Static("name")).unwrap();
1870        let first_child = ExecutionId::Derived(top_level.next_level(&join_set_id));
1871        let ser = first_child.to_string();
1872        assert_eq!(format!("{top_level}.n:name.0"), ser);
1873        let parsed = ExecutionId::from_str(&ser).unwrap();
1874        assert_eq!(first_child, parsed);
1875    }
1876
1877    #[test]
1878    fn execution_id_increment_twice() {
1879        let top_level = ExecutionId::generate();
1880        let join_set_id = JoinSetId::new(JoinSetKind::Named, StrVariant::Static("name")).unwrap();
1881        let first_child = top_level.next_level(&join_set_id);
1882        let second_child = ExecutionId::Derived(first_child.get_incremented());
1883        let ser = second_child.to_string();
1884        assert_eq!(format!("{top_level}.n:name.1"), ser);
1885        let parsed = ExecutionId::from_str(&ser).unwrap();
1886        assert_eq!(second_child, parsed);
1887    }
1888
1889    #[test]
1890    fn execution_id_next_level_twice() {
1891        let top_level = ExecutionId::generate();
1892        let join_set_id_outer =
1893            JoinSetId::new(JoinSetKind::Generated, StrVariant::Static("gg")).unwrap();
1894        let join_set_id_inner =
1895            JoinSetId::new(JoinSetKind::OneOff, StrVariant::Static("oo")).unwrap();
1896        let execution_id = ExecutionId::Derived(
1897            top_level
1898                .next_level(&join_set_id_outer)
1899                .get_incremented()
1900                .next_level(&join_set_id_inner)
1901                .get_incremented(),
1902        );
1903        let ser = execution_id.to_string();
1904        assert_eq!(format!("{top_level}.g:gg.1.o:oo.1"), ser);
1905        let parsed = ExecutionId::from_str(&ser).unwrap();
1906        assert_eq!(execution_id, parsed);
1907    }
1908
1909    #[test]
1910    fn execution_id_hash_should_be_stable() {
1911        let parent = ExecutionId::from_parts(1, 2);
1912        let join_set_id = JoinSetId::new(JoinSetKind::Named, StrVariant::Static("name")).unwrap();
1913        let sibling_1 = parent.next_level(&join_set_id);
1914        let sibling_2 = ExecutionId::Derived(sibling_1.get_incremented());
1915        let sibling_1 = ExecutionId::Derived(sibling_1);
1916        let join_set_id_inner =
1917            JoinSetId::new(JoinSetKind::OneOff, StrVariant::Static("oo")).unwrap();
1918        let child =
1919            ExecutionId::Derived(sibling_1.next_level(&join_set_id_inner).get_incremented());
1920        let parent = parent.random_seed();
1921        let sibling_1 = sibling_1.random_seed();
1922        let sibling_2 = sibling_2.random_seed();
1923        let child = child.random_seed();
1924        let vec = vec![parent, sibling_1, sibling_2, child];
1925        insta::assert_debug_snapshot!(vec);
1926        // check that every hash is unique
1927        let set: hashbrown::HashSet<_> = vec.into_iter().collect();
1928        assert_eq!(4, set.len());
1929    }
1930
1931    #[test]
1932    fn hash_of_str_variants_should_be_equal() {
1933        let input = "foo";
1934        let left = StrVariant::Arc(Arc::from(input));
1935        let right = StrVariant::Static(input);
1936        assert_eq!(left, right);
1937        let mut left_hasher = DefaultHasher::new();
1938        left.hash(&mut left_hasher);
1939        let mut right_hasher = DefaultHasher::new();
1940        right.hash(&mut right_hasher);
1941        let left_hasher = left_hasher.finish();
1942        let right_hasher = right_hasher.finish();
1943        println!("left: {left_hasher:x}, right: {right_hasher:x}");
1944        assert_eq!(left_hasher, right_hasher);
1945    }
1946
1947    #[cfg(madsim)]
1948    #[tokio::test]
1949    async fn join_set_serde_should_be_consistent() {
1950        use crate::{JoinSetId, JoinSetKind};
1951        use strum::IntoEnumIterator;
1952        for kind in JoinSetKind::iter() {
1953            let join_set_id = JoinSetId::new(kind, StrVariant::from("name")).unwrap();
1954            let ser = serde_json::to_string(&join_set_id).unwrap();
1955            let deser = serde_json::from_str(&ser).unwrap();
1956            assert_eq!(join_set_id, deser);
1957        }
1958    }
1959}