Skip to main content

objectiveai_sdk/functions/
function.rs

1//! Function types and client-side compilation.
2//!
3//! # Output Computation
4//!
5//! Functions do **not** have a top-level output expression. Instead, each task has its
6//! own `output` expression that transforms its raw result into a [`TaskOutputOwned`].
7//! The function's final output is computed as a **weighted average** of all task outputs
8//! using profile weights.
9//!
10//! - If a function has only 1 task, that task's output becomes the function's output directly
11//! - If a function has multiple tasks, each task's output is weighted and averaged
12//!
13//! Each task's `output` expression must return a valid `TaskOutputOwned` for the function's type:
14//! - **Scalar functions**: each task must return `Scalar(value)` where value is in [0, 1]
15//! - **Vector functions**: each task must return `Vector(values)` where values sum to ~1
16//!
17//! [`TaskOutputOwned`]: super::expression::TaskOutputOwned
18
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21
22/// A Function definition, either remote or inline.
23///
24/// Functions are composable scoring pipelines that transform structured input
25/// into scores. Each task has an `output` expression that transforms its raw result
26/// into a `TaskOutputOwned`. The function's final output is the weighted average of
27/// all task outputs using profile weights.
28///
29/// Use [`compile_tasks`](Self::compile_tasks) to preview how task expressions resolve
30/// for given inputs.
31#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
32#[serde(untagged)]
33#[schemars(rename = "functions.Function")]
34pub enum Function {
35    /// A remote function with metadata (description, schema, etc.).
36    #[schemars(title = "Remote")]
37    Remote(RemoteFunction),
38    /// An inline function definition without metadata.
39    #[schemars(title = "Inline")]
40    Inline(InlineFunction),
41}
42
43impl Function {
44    /// Validates the input against the function's input schema.
45    ///
46    /// For remote functions, checks whether the provided input conforms to
47    /// the function's JSON Schema definition. For inline functions, returns
48    /// `None` since they lack schema definitions.
49    ///
50    /// # Returns
51    ///
52    /// - `Some(true)` if the input is valid against the schema
53    /// - `Some(false)` if the input is invalid
54    /// - `None` for inline functions (no schema to validate against)
55    pub fn validate_input(
56        &self,
57        input: &super::expression::InputValue,
58    ) -> Option<bool> {
59        match self {
60            Function::Remote(remote_function) => {
61                Some(remote_function.input_schema().validate_input(input))
62            }
63            Function::Inline(_) => None,
64        }
65    }
66
67    /// Compiles task expressions to show the final tasks for a given input.
68    ///
69    /// Evaluates all expressions (JMESPath or Starlark) in the function's tasks
70    /// using the provided input data. Tasks with `skip` expressions that evaluate
71    /// to true return `None`. Tasks with `map` fields produce multiple task instances.
72    ///
73    /// # Returns
74    ///
75    /// A vector where each element corresponds to a task definition:
76    /// - `None` if the task was skipped
77    /// - `Some(CompiledTask::One(...))` for non-mapped tasks
78    /// - `Some(CompiledTask::Many(...))` for mapped tasks
79    pub fn compile_tasks(
80        self,
81        input: &super::expression::InputValue,
82    ) -> Result<
83        Vec<Option<super::CompiledTask>>,
84        super::expression::ExpressionError,
85    > {
86        // extract task expressions
87        let task_exprs = match self {
88            Function::Remote(RemoteFunction::Scalar { tasks, .. }) => tasks,
89            Function::Remote(RemoteFunction::Vector { tasks, .. }) => tasks,
90            Function::Inline(InlineFunction::Scalar { tasks, .. }) => tasks,
91            Function::Inline(InlineFunction::Vector { tasks, .. }) => tasks,
92        };
93
94        // prepare params for compiling expressions
95        let mut params =
96            super::expression::Params::Ref(super::expression::ParamsRef {
97                input,
98                output: None,
99                map: None,
100            });
101
102        // compile tasks
103        let mut tasks = Vec::with_capacity(task_exprs.len());
104        for mut task_expr in task_exprs {
105            tasks.push(
106                if let Some(skip_expr) = task_expr.take_skip()
107                    && skip_expr.compile_one::<bool>(&params)?
108                {
109                    // None if task is skipped
110                    None
111                } else if let Some(map_expr) = task_expr.map() {
112                    // evaluate map expression to get count
113                    let count: u64 = map_expr.compile_one(&params)?;
114                    // compile task for each map index
115                    let mut map_tasks = Vec::with_capacity(count as usize);
116                    for i in 0..count {
117                        // set map index
118                        match &mut params {
119                            super::expression::Params::Ref(params_ref) => {
120                                params_ref.map = Some(i);
121                            }
122                            _ => unreachable!(),
123                        }
124                        // compile task with map index
125                        map_tasks.push(task_expr.clone().compile(&params)?);
126                        // reset map index
127                        match &mut params {
128                            super::expression::Params::Ref(params_ref) => {
129                                params_ref.map = None;
130                            }
131                            _ => unreachable!(),
132                        }
133                    }
134                    Some(super::CompiledTask::Many(map_tasks))
135                } else {
136                    // compile single task
137                    Some(super::CompiledTask::One(task_expr.compile(&params)?))
138                },
139            );
140        }
141
142        // compiled tasks
143        Ok(tasks)
144    }
145
146    // /// Computes the final output given input and task outputs.
147    // ///
148    // /// Evaluates the function's output expression using the provided input data
149    // /// and task results. Also validates that the output meets constraints:
150    // /// - Scalar functions: output must be in [0, 1]
151    // /// - Vector functions: output must sum to approximately 1
152    // pub fn compile_output(
153    //     self,
154    //     input: &super::expression::InputValue,
155    //     task_outputs: &[Option<super::expression::TaskOutput>],
156    // ) -> Result<
157    //     super::expression::CompiledFunctionOutput,
158    //     super::expression::ExpressionError,
159    // > {
160    //     #[derive(Clone, Copy)]
161    //     enum FunctionType {
162    //         Scalar,
163    //         Vector,
164    //     }
165
166    //     // prepare params for compiling output_length expression
167    //     let mut params =
168    //         super::expression::Params::Ref(super::expression::ParamsRef {
169    //             input,
170    //             output: None,
171    //             map: None,
172    //         });
173
174    //     // extract output expression and output_length
175    //     let (function_type, output_expr, output_length) = match self {
176    //         Function::Remote(RemoteFunction::Scalar { output, .. }) => {
177    //             (FunctionType::Scalar, output, None)
178    //         }
179    //         Function::Remote(RemoteFunction::Vector {
180    //             output,
181    //             output_length,
182    //             ..
183    //         }) => (
184    //             FunctionType::Vector,
185    //             output,
186    //             Some(output_length.compile_one(&params)?),
187    //         ),
188    //         Function::Inline(InlineFunction::Scalar { output, .. }) => {
189    //             (FunctionType::Scalar, output, None)
190    //         }
191    //         Function::Inline(InlineFunction::Vector { output, .. }) => {
192    //             (FunctionType::Vector, output, None)
193    //         }
194    //     };
195
196    //     // prepare params for compiling output expression
197    //     match &mut params {
198    //         super::expression::Params::Ref(params_ref) => {
199    //             params_ref.tasks = task_outputs;
200    //         }
201    //         _ => unreachable!(),
202    //     }
203
204    //     // compile output
205    //     let output = output_expr
206    //         .compile_one::<super::expression::FunctionOutput>(&params)?;
207
208    //     // validate output
209    //     let valid = match (function_type, &output, output_length) {
210    //         (
211    //             FunctionType::Scalar,
212    //             &super::expression::FunctionOutput::Scalar(scalar),
213    //             _,
214    //         ) => {
215    //             scalar >= rust_decimal::Decimal::ZERO
216    //                 && scalar <= rust_decimal::Decimal::ONE
217    //         }
218    //         (
219    //             FunctionType::Vector,
220    //             super::expression::FunctionOutput::Vector(vector),
221    //             Some(length),
222    //         ) => {
223    //             let sum = vector.iter().sum::<rust_decimal::Decimal>();
224    //             vector.len() == length as usize
225    //                 && sum >= rust_decimal::dec!(0.99)
226    //                 && sum <= rust_decimal::dec!(1.01)
227    //         }
228    //         (
229    //             FunctionType::Vector,
230    //             super::expression::FunctionOutput::Vector(vector),
231    //             None,
232    //         ) => {
233    //             let sum = vector.iter().sum::<rust_decimal::Decimal>();
234    //             sum >= rust_decimal::dec!(0.99)
235    //                 && sum <= rust_decimal::dec!(1.01)
236    //         }
237    //         _ => false,
238    //     };
239
240    //     // compiled output
241    //     Ok(super::expression::CompiledFunctionOutput { output, valid })
242    // }
243
244    /// Computes the expected output length for a vector function.
245    ///
246    /// Evaluates the `output_length` expression to determine how many elements
247    /// the output vector should contain. This is only applicable to remote
248    /// vector functions which have an `output_length` field.
249    ///
250    /// # Arguments
251    ///
252    /// * `input` - The function input used to compute the output length
253    ///
254    /// # Returns
255    ///
256    /// - `Ok(Some(u64))` - The expected output length for remote vector functions
257    /// - `Ok(None)` - For scalar functions or inline functions
258    /// - `Err(ExpressionError)` - If the expression fails to compile
259    pub fn compile_output_length(
260        self,
261        input: &super::expression::InputValue,
262    ) -> Result<Option<u64>, super::expression::ExpressionError> {
263        let output_length_expr = match self {
264            Function::Remote(RemoteFunction::Scalar { .. }) => None,
265            Function::Remote(RemoteFunction::Vector {
266                output_length, ..
267            }) => Some(output_length),
268            Function::Inline(InlineFunction::Scalar { .. }) => None,
269            Function::Inline(InlineFunction::Vector { .. }) => None,
270        };
271        match output_length_expr {
272            Some(output_length_expr) => {
273                // prepare params for compiling output_length expression
274                let params = super::expression::Params::Ref(
275                    super::expression::ParamsRef {
276                        input,
277                        output: None,
278                        map: None,
279                    },
280                );
281                // compile output_length
282                let output_length = output_length_expr.compile_one(&params)?;
283                Ok(Some(output_length))
284            }
285            None => Ok(None),
286        }
287    }
288
289    /// Compiles the `input_split` expression to split input into multiple sub-inputs.
290    ///
291    /// Used by strategies like Swiss System that need to partition input into
292    /// smaller pools. The expression transforms the original input into an array
293    /// of inputs, where each element can be processed independently.
294    ///
295    /// # Arguments
296    ///
297    /// * `input` - The original function input to split
298    ///
299    /// # Returns
300    ///
301    /// - `Ok(Some(Vec<Input>))` - The split inputs for vector functions with `input_split` defined
302    /// - `Ok(None)` - For scalar functions or functions without `input_split`
303    /// - `Err(ExpressionError)` - If the expression fails to compile
304    pub fn compile_input_split(
305        self,
306        input: &super::expression::InputValue,
307    ) -> Result<
308        Option<Vec<super::expression::InputValue>>,
309        super::expression::ExpressionError,
310    > {
311        let input_split_expr = match self {
312            Function::Remote(RemoteFunction::Scalar { .. }) => None,
313            Function::Remote(RemoteFunction::Vector {
314                input_split, ..
315            }) => Some(input_split),
316            Function::Inline(InlineFunction::Scalar { .. }) => None,
317            Function::Inline(InlineFunction::Vector {
318                input_split, ..
319            }) => input_split,
320        };
321        match input_split_expr {
322            Some(input_split_expr) => {
323                // prepare params for compiling input_split expression
324                let params = super::expression::Params::Ref(
325                    super::expression::ParamsRef {
326                        input,
327                        output: None,
328                        map: None,
329                    },
330                );
331                // compile input_split
332                let input_split = input_split_expr.compile_one(&params)?;
333                Ok(Some(input_split))
334            }
335            None => Ok(None),
336        }
337    }
338
339    /// Compiles the `input_merge` expression to merge multiple sub-inputs back into one.
340    ///
341    /// Used by strategies like Swiss System to recombine a subset of split inputs
342    /// into a single input for pool execution. The expression transforms an array
343    /// of inputs (a subset from `input_split`) into a single merged input.
344    ///
345    /// # Arguments
346    ///
347    /// * `input` - An array of inputs to merge (typically a subset from `compile_input_split`)
348    ///
349    /// # Returns
350    ///
351    /// - `Ok(Some(Input))` - The merged input for vector functions with `input_merge` defined
352    /// - `Ok(None)` - For scalar functions or functions without `input_merge`
353    /// - `Err(ExpressionError)` - If the expression fails to compile
354    pub fn compile_input_merge(
355        self,
356        input: &super::expression::InputValue,
357    ) -> Result<
358        Option<super::expression::InputValue>,
359        super::expression::ExpressionError,
360    > {
361        let input_merge_expr = match self {
362            Function::Remote(RemoteFunction::Scalar { .. }) => None,
363            Function::Remote(RemoteFunction::Vector {
364                input_merge, ..
365            }) => Some(input_merge),
366            Function::Inline(InlineFunction::Scalar { .. }) => None,
367            Function::Inline(InlineFunction::Vector {
368                input_merge, ..
369            }) => input_merge,
370        };
371        match input_merge_expr {
372            Some(input_merge_expr) => {
373                // prepare params for compiling input_merge expression
374                let params = super::expression::Params::Ref(
375                    super::expression::ParamsRef {
376                        input,
377                        output: None,
378                        map: None,
379                    },
380                );
381                // compile input_merge
382                let input_merge = input_merge_expr.compile_one(&params)?;
383                Ok(Some(input_merge))
384            }
385            None => Ok(None),
386        }
387    }
388
389    /// Returns the function's description, if available.
390    pub fn description(&self) -> Option<&str> {
391        match self {
392            Function::Remote(remote_function) => {
393                Some(remote_function.description())
394            }
395            Function::Inline(_) => None,
396        }
397    }
398
399    /// Returns the function's input schema, if available.
400    pub fn input_schema(&self) -> Option<&super::expression::InputSchema> {
401        match self {
402            Function::Remote(remote_function) => {
403                Some(remote_function.input_schema())
404            }
405            Function::Inline(_) => None,
406        }
407    }
408
409    /// Returns the function's tasks.
410    pub fn tasks(&self) -> &[super::TaskExpression] {
411        match self {
412            Function::Remote(remote_function) => remote_function.tasks(),
413            Function::Inline(inline_function) => inline_function.tasks(),
414        }
415    }
416
417    /// Returns the function's expected output length expression, if defined.
418    pub fn output_length(&self) -> Option<&super::expression::Expression> {
419        match self {
420            Function::Remote(remote_function) => {
421                remote_function.output_length()
422            }
423            Function::Inline(_) => None,
424        }
425    }
426
427    /// Returns the function's input_split expression, if defined.
428    pub fn input_split(&self) -> Option<&super::expression::Expression> {
429        match self {
430            Function::Remote(remote_function) => remote_function.input_split(),
431            Function::Inline(inline_function) => inline_function.input_split(),
432        }
433    }
434
435    /// Returns the function's input_merge expression, if defined.
436    pub fn input_merge(&self) -> Option<&super::expression::Expression> {
437        match self {
438            Function::Remote(remote_function) => remote_function.input_merge(),
439            Function::Inline(inline_function) => inline_function.input_merge(),
440        }
441    }
442}
443
444/// A remote function with full metadata.
445///
446/// Remote functions are stored as `function.json` in repositories and
447/// referenced by `remote/owner/repository`. They include documentation fields
448/// that inline functions lack.
449#[derive(
450    Debug,
451    Clone,
452    PartialEq,
453    Serialize,
454    Deserialize,
455    JsonSchema,
456    arbitrary::Arbitrary,
457)]
458#[serde(tag = "type")]
459#[schemars(rename = "functions.RemoteFunction")]
460pub enum RemoteFunction {
461    /// Produces a single score in [0, 1].
462    #[schemars(title = "Scalar")]
463    #[serde(rename = "scalar.function")]
464    Scalar {
465        /// Human-readable description of what the function does.
466        description: String,
467        /// JSON Schema defining the expected input structure.
468        input_schema: super::expression::InputSchema,
469        /// The list of tasks to execute. Tasks with a `map` expression are
470        /// expanded into multiple instances. Each instance is compiled with
471        /// `map` set to the current integer index.
472        /// Receives: `input`, `map` (if mapped).
473        tasks: Vec<super::TaskExpression>,
474    },
475    /// Produces a vector of scores that sums to 1.
476    #[schemars(title = "Vector")]
477    #[serde(rename = "vector.function")]
478    Vector {
479        /// Human-readable description of what the function does.
480        description: String,
481        /// JSON Schema defining the expected input structure.
482        input_schema: super::expression::InputSchema,
483        /// The list of tasks to execute. Tasks with a `map` expression are
484        /// expanded into multiple instances. Each instance is compiled with
485        /// `map` set to the current integer index.
486        /// Receives: `input`, `map` (if mapped).
487        tasks: Vec<super::TaskExpression>,
488        /// Expression computing the expected output vector length for task outputs.
489        /// Receives: `input`.
490        output_length: super::expression::Expression,
491        /// Expression transforming input into an input array of the output_length
492        /// When the Function is executed with any input from the array,
493        /// The output_length should be 1.
494        /// Receives: `input`.
495        input_split: super::expression::Expression,
496        /// Expression transforming an array of inputs computed by `input_split`
497        /// into a single Input object for the Function.
498        /// Receives: `input` (as an array).
499        input_merge: super::expression::Expression,
500    },
501}
502
503impl RemoteFunction {
504    /// Returns the function's description.
505    pub fn description(&self) -> &str {
506        match self {
507            RemoteFunction::Scalar { description, .. } => description,
508            RemoteFunction::Vector { description, .. } => description,
509        }
510    }
511
512    /// Returns the function's input schema.
513    pub fn input_schema(&self) -> &super::expression::InputSchema {
514        match self {
515            RemoteFunction::Scalar { input_schema, .. } => input_schema,
516            RemoteFunction::Vector { input_schema, .. } => input_schema,
517        }
518    }
519
520    /// Returns the function's tasks.
521    pub fn tasks(&self) -> &[super::TaskExpression] {
522        match self {
523            RemoteFunction::Scalar { tasks, .. } => tasks,
524            RemoteFunction::Vector { tasks, .. } => tasks,
525        }
526    }
527
528    /// Returns the function's expected output length, if defined (vector functions only).
529    pub fn output_length(&self) -> Option<&super::expression::Expression> {
530        match self {
531            RemoteFunction::Scalar { .. } => None,
532            RemoteFunction::Vector { output_length, .. } => Some(output_length),
533        }
534    }
535
536    /// Returns the function's input_split expression, if defined (vector functions only).
537    pub fn input_split(&self) -> Option<&super::expression::Expression> {
538        match self {
539            RemoteFunction::Scalar { .. } => None,
540            RemoteFunction::Vector { input_split, .. } => Some(input_split),
541        }
542    }
543
544    /// Returns the function's input_merge expression, if defined (vector functions only).
545    pub fn input_merge(&self) -> Option<&super::expression::Expression> {
546        match self {
547            RemoteFunction::Scalar { .. } => None,
548            RemoteFunction::Vector { input_merge, .. } => Some(input_merge),
549        }
550    }
551
552    pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
553        self.tasks().iter().filter_map(|task| match task {
554            super::TaskExpression::ScalarFunction(t) => Some(&t.path),
555            super::TaskExpression::VectorFunction(t) => Some(&t.path),
556            _ => None,
557        })
558    }
559}
560
561/// An inline function definition without metadata.
562///
563/// Used when embedding function logic directly in requests rather than
564/// referencing a remote function. Lacks description and input
565/// schema fields.
566#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
567#[serde(tag = "type")]
568#[schemars(rename = "functions.InlineFunction")]
569pub enum InlineFunction {
570    /// Produces a single score in [0, 1].
571    #[schemars(title = "Scalar")]
572    #[serde(rename = "scalar.function")]
573    Scalar {
574        /// The list of tasks to execute. Tasks with a `map` expression are
575        /// expanded into multiple instances. Each instance is compiled with
576        /// `map` set to the current integer index.
577        /// Receives: `input`, `map` (if mapped).
578        tasks: Vec<super::TaskExpression>,
579    },
580    /// Produces a vector of scores that sums to 1.
581    #[schemars(title = "Vector")]
582    #[serde(rename = "vector.function")]
583    Vector {
584        /// The list of tasks to execute. Tasks with a `map` expression are
585        /// expanded into multiple instances. Each instance is compiled with
586        /// `map` set to the current integer index.
587        /// Receives: `input`, `map` (if mapped).
588        tasks: Vec<super::TaskExpression>,
589        /// Expression transforming input into an input array of the output_length
590        /// When the Function is executed with any input from the array,
591        /// The output_length should be 1.
592        /// Receives: `input`.
593        /// Only required if the request uses a strategy that needs input splitting.
594        input_split: Option<super::expression::Expression>,
595        /// Expression transforming an array of inputs computed by `input_split`
596        /// into a single Input object for the Function.
597        /// Receives: `input` (as an array).
598        /// Only required if the request uses a strategy that needs input splitting.
599        input_merge: Option<super::expression::Expression>,
600    },
601}
602
603impl InlineFunction {
604    /// Returns the function's tasks.
605    pub fn tasks(&self) -> &[super::TaskExpression] {
606        match self {
607            InlineFunction::Scalar { tasks, .. } => tasks,
608            InlineFunction::Vector { tasks, .. } => tasks,
609        }
610    }
611
612    /// Returns the function's input_split expression, if defined (vector functions only).
613    pub fn input_split(&self) -> Option<&super::expression::Expression> {
614        match self {
615            InlineFunction::Scalar { .. } => None,
616            InlineFunction::Vector { input_split, .. } => input_split.as_ref(),
617        }
618    }
619
620    /// Returns the function's input_merge expression, if defined (vector functions only).
621    pub fn input_merge(&self) -> Option<&super::expression::Expression> {
622        match self {
623            InlineFunction::Scalar { .. } => None,
624            InlineFunction::Vector { input_merge, .. } => input_merge.as_ref(),
625        }
626    }
627
628    pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
629        self.tasks().iter().filter_map(|task| match task {
630            super::TaskExpression::ScalarFunction(t) => Some(&t.path),
631            super::TaskExpression::VectorFunction(t) => Some(&t.path),
632            _ => None,
633        })
634    }
635}
636
637#[derive(
638    Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, JsonSchema,
639)]
640#[schemars(rename = "functions.FunctionType")]
641pub enum FunctionType {
642    #[schemars(title = "Scalar")]
643    #[serde(rename = "scalar.function")]
644    Scalar,
645    #[schemars(title = "Vector")]
646    #[serde(rename = "vector.function")]
647    Vector,
648}