Skip to main content

objectiveai_sdk/functions/
function.rs

1//! Function types and client-side compilation.
2//!
3//! # Output Computation
4//!
5//! Functions do **not** have a top-level output expression. Instead, each task has its
6//! own `output` expression that transforms its raw result into a [`TaskOutputOwned`].
7//! The function's final output is computed as a **weighted average** of all task outputs
8//! using profile weights.
9//!
10//! - If a function has only 1 task, that task's output becomes the function's output directly
11//! - If a function has multiple tasks, each task's output is weighted and averaged
12//!
13//! Each task's `output` expression must return a valid `TaskOutputOwned` for the function's type:
14//! - **Scalar functions**: each task must return `Scalar(value)` where value is in [0, 1]
15//! - **Vector functions**: each task must return `Vector(values)` where values sum to ~1
16//!
17//! [`TaskOutputOwned`]: super::expression::TaskOutputOwned
18
19use serde::{Deserialize, Serialize};
20use schemars::JsonSchema;
21
22/// A Function definition, either remote or inline.
23///
24/// Functions are composable scoring pipelines that transform structured input
25/// into scores. Each task has an `output` expression that transforms its raw result
26/// into a `TaskOutputOwned`. The function's final output is the weighted average of
27/// all task outputs using profile weights.
28///
29/// Use [`compile_tasks`](Self::compile_tasks) to preview how task expressions resolve
30/// for given inputs.
31#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
32#[serde(untagged)]
33#[schemars(rename = "functions.Function")]
34pub enum Function {
35    /// A remote function with metadata (description, schema, etc.).
36    #[schemars(title = "Remote")]
37    Remote(RemoteFunction),
38    /// An inline function definition without metadata.
39    #[schemars(title = "Inline")]
40    Inline(InlineFunction),
41}
42
43impl Function {
44    /// Validates the input against the function's input schema.
45    ///
46    /// For remote functions, checks whether the provided input conforms to
47    /// the function's JSON Schema definition. For inline functions, returns
48    /// `None` since they lack schema definitions.
49    ///
50    /// # Returns
51    ///
52    /// - `Some(true)` if the input is valid against the schema
53    /// - `Some(false)` if the input is invalid
54    /// - `None` for inline functions (no schema to validate against)
55    pub fn validate_input(
56        &self,
57        input: &super::expression::InputValue,
58    ) -> Option<bool> {
59        match self {
60            Function::Remote(remote_function) => {
61                Some(remote_function.input_schema().validate_input(input))
62            }
63            Function::Inline(_) => None,
64        }
65    }
66
67    /// Compiles task expressions to show the final tasks for a given input.
68    ///
69    /// Evaluates all expressions (JMESPath or Starlark) in the function's tasks
70    /// using the provided input data. Tasks with `skip` expressions that evaluate
71    /// to true return `None`. Tasks with `map` fields produce multiple task instances.
72    ///
73    /// # Returns
74    ///
75    /// A vector where each element corresponds to a task definition:
76    /// - `None` if the task was skipped
77    /// - `Some(CompiledTask::One(...))` for non-mapped tasks
78    /// - `Some(CompiledTask::Many(...))` for mapped tasks
79    pub fn compile_tasks(
80        self,
81        input: &super::expression::InputValue,
82    ) -> Result<
83        Vec<Option<super::CompiledTask>>,
84        super::expression::ExpressionError,
85    > {
86        // extract task expressions
87        let task_exprs = match self {
88            Function::Remote(RemoteFunction::Scalar { tasks, .. }) => tasks,
89            Function::Remote(RemoteFunction::Vector { tasks, .. }) => tasks,
90            Function::Inline(InlineFunction::Scalar { tasks, .. }) => tasks,
91            Function::Inline(InlineFunction::Vector { tasks, .. }) => tasks,
92        };
93
94        // prepare params for compiling expressions
95        let mut params =
96            super::expression::Params::Ref(super::expression::ParamsRef {
97                input,
98                output: None,
99                map: None,
100                tasks_min: None,
101                tasks_max: None,
102                depth: None,
103                name: None,
104                spec: None,
105            });
106
107        // compile tasks
108        let mut tasks = Vec::with_capacity(task_exprs.len());
109        for mut task_expr in task_exprs {
110            tasks.push(
111                if let Some(skip_expr) = task_expr.take_skip()
112                    && skip_expr.compile_one::<bool>(&params)?
113                {
114                    // None if task is skipped
115                    None
116                } else if let Some(map_expr) = task_expr.map() {
117                    // evaluate map expression to get count
118                    let count: u64 = map_expr.compile_one(&params)?;
119                    // compile task for each map index
120                    let mut map_tasks = Vec::with_capacity(count as usize);
121                    for i in 0..count {
122                        // set map index
123                        match &mut params {
124                            super::expression::Params::Ref(params_ref) => {
125                                params_ref.map = Some(i);
126                            }
127                            _ => unreachable!(),
128                        }
129                        // compile task with map index
130                        map_tasks.push(task_expr.clone().compile(&params)?);
131                        // reset map index
132                        match &mut params {
133                            super::expression::Params::Ref(params_ref) => {
134                                params_ref.map = None;
135                            }
136                            _ => unreachable!(),
137                        }
138                    }
139                    Some(super::CompiledTask::Many(map_tasks))
140                } else {
141                    // compile single task
142                    Some(super::CompiledTask::One(task_expr.compile(&params)?))
143                },
144            );
145        }
146
147        // compiled tasks
148        Ok(tasks)
149    }
150
151    // /// Computes the final output given input and task outputs.
152    // ///
153    // /// Evaluates the function's output expression using the provided input data
154    // /// and task results. Also validates that the output meets constraints:
155    // /// - Scalar functions: output must be in [0, 1]
156    // /// - Vector functions: output must sum to approximately 1
157    // pub fn compile_output(
158    //     self,
159    //     input: &super::expression::InputValue,
160    //     task_outputs: &[Option<super::expression::TaskOutput>],
161    // ) -> Result<
162    //     super::expression::CompiledFunctionOutput,
163    //     super::expression::ExpressionError,
164    // > {
165    //     #[derive(Clone, Copy)]
166    //     enum FunctionType {
167    //         Scalar,
168    //         Vector,
169    //     }
170
171    //     // prepare params for compiling output_length expression
172    //     let mut params =
173    //         super::expression::Params::Ref(super::expression::ParamsRef {
174    //             input,
175    //             output: None,
176    //             map: None,
177    //         });
178
179    //     // extract output expression and output_length
180    //     let (function_type, output_expr, output_length) = match self {
181    //         Function::Remote(RemoteFunction::Scalar { output, .. }) => {
182    //             (FunctionType::Scalar, output, None)
183    //         }
184    //         Function::Remote(RemoteFunction::Vector {
185    //             output,
186    //             output_length,
187    //             ..
188    //         }) => (
189    //             FunctionType::Vector,
190    //             output,
191    //             Some(output_length.compile_one(&params)?),
192    //         ),
193    //         Function::Inline(InlineFunction::Scalar { output, .. }) => {
194    //             (FunctionType::Scalar, output, None)
195    //         }
196    //         Function::Inline(InlineFunction::Vector { output, .. }) => {
197    //             (FunctionType::Vector, output, None)
198    //         }
199    //     };
200
201    //     // prepare params for compiling output expression
202    //     match &mut params {
203    //         super::expression::Params::Ref(params_ref) => {
204    //             params_ref.tasks = task_outputs;
205    //         }
206    //         _ => unreachable!(),
207    //     }
208
209    //     // compile output
210    //     let output = output_expr
211    //         .compile_one::<super::expression::FunctionOutput>(&params)?;
212
213    //     // validate output
214    //     let valid = match (function_type, &output, output_length) {
215    //         (
216    //             FunctionType::Scalar,
217    //             &super::expression::FunctionOutput::Scalar(scalar),
218    //             _,
219    //         ) => {
220    //             scalar >= rust_decimal::Decimal::ZERO
221    //                 && scalar <= rust_decimal::Decimal::ONE
222    //         }
223    //         (
224    //             FunctionType::Vector,
225    //             super::expression::FunctionOutput::Vector(vector),
226    //             Some(length),
227    //         ) => {
228    //             let sum = vector.iter().sum::<rust_decimal::Decimal>();
229    //             vector.len() == length as usize
230    //                 && sum >= rust_decimal::dec!(0.99)
231    //                 && sum <= rust_decimal::dec!(1.01)
232    //         }
233    //         (
234    //             FunctionType::Vector,
235    //             super::expression::FunctionOutput::Vector(vector),
236    //             None,
237    //         ) => {
238    //             let sum = vector.iter().sum::<rust_decimal::Decimal>();
239    //             sum >= rust_decimal::dec!(0.99)
240    //                 && sum <= rust_decimal::dec!(1.01)
241    //         }
242    //         _ => false,
243    //     };
244
245    //     // compiled output
246    //     Ok(super::expression::CompiledFunctionOutput { output, valid })
247    // }
248
249    /// Computes the expected output length for a vector function.
250    ///
251    /// Evaluates the `output_length` expression to determine how many elements
252    /// the output vector should contain. This is only applicable to remote
253    /// vector functions which have an `output_length` field.
254    ///
255    /// # Arguments
256    ///
257    /// * `input` - The function input used to compute the output length
258    ///
259    /// # Returns
260    ///
261    /// - `Ok(Some(u64))` - The expected output length for remote vector functions
262    /// - `Ok(None)` - For scalar functions or inline functions
263    /// - `Err(ExpressionError)` - If the expression fails to compile
264    pub fn compile_output_length(
265        self,
266        input: &super::expression::InputValue,
267    ) -> Result<Option<u64>, super::expression::ExpressionError> {
268        let output_length_expr = match self {
269            Function::Remote(RemoteFunction::Scalar { .. }) => None,
270            Function::Remote(RemoteFunction::Vector {
271                output_length, ..
272            }) => Some(output_length),
273            Function::Inline(InlineFunction::Scalar { .. }) => None,
274            Function::Inline(InlineFunction::Vector { .. }) => None,
275        };
276        match output_length_expr {
277            Some(output_length_expr) => {
278                // prepare params for compiling output_length expression
279                let params = super::expression::Params::Ref(
280                    super::expression::ParamsRef {
281                        input,
282                        output: None,
283                        map: None,
284                        tasks_min: None,
285                        tasks_max: None,
286                        depth: None,
287                        name: None,
288                        spec: None,
289                    },
290                );
291                // compile output_length
292                let output_length = output_length_expr.compile_one(&params)?;
293                Ok(Some(output_length))
294            }
295            None => Ok(None),
296        }
297    }
298
299    /// Compiles the `input_split` expression to split input into multiple sub-inputs.
300    ///
301    /// Used by strategies like Swiss System that need to partition input into
302    /// smaller pools. The expression transforms the original input into an array
303    /// of inputs, where each element can be processed independently.
304    ///
305    /// # Arguments
306    ///
307    /// * `input` - The original function input to split
308    ///
309    /// # Returns
310    ///
311    /// - `Ok(Some(Vec<Input>))` - The split inputs for vector functions with `input_split` defined
312    /// - `Ok(None)` - For scalar functions or functions without `input_split`
313    /// - `Err(ExpressionError)` - If the expression fails to compile
314    pub fn compile_input_split(
315        self,
316        input: &super::expression::InputValue,
317    ) -> Result<
318        Option<Vec<super::expression::InputValue>>,
319        super::expression::ExpressionError,
320    > {
321        let input_split_expr = match self {
322            Function::Remote(RemoteFunction::Scalar { .. }) => None,
323            Function::Remote(RemoteFunction::Vector {
324                input_split, ..
325            }) => Some(input_split),
326            Function::Inline(InlineFunction::Scalar { .. }) => None,
327            Function::Inline(InlineFunction::Vector {
328                input_split, ..
329            }) => input_split,
330        };
331        match input_split_expr {
332            Some(input_split_expr) => {
333                // prepare params for compiling input_split expression
334                let params = super::expression::Params::Ref(
335                    super::expression::ParamsRef {
336                        input,
337                        output: None,
338                        map: None,
339                        tasks_min: None,
340                        tasks_max: None,
341                        depth: None,
342                        name: None,
343                        spec: None,
344                    },
345                );
346                // compile input_split
347                let input_split = input_split_expr.compile_one(&params)?;
348                Ok(Some(input_split))
349            }
350            None => Ok(None),
351        }
352    }
353
354    /// Compiles the `input_merge` expression to merge multiple sub-inputs back into one.
355    ///
356    /// Used by strategies like Swiss System to recombine a subset of split inputs
357    /// into a single input for pool execution. The expression transforms an array
358    /// of inputs (a subset from `input_split`) into a single merged input.
359    ///
360    /// # Arguments
361    ///
362    /// * `input` - An array of inputs to merge (typically a subset from `compile_input_split`)
363    ///
364    /// # Returns
365    ///
366    /// - `Ok(Some(Input))` - The merged input for vector functions with `input_merge` defined
367    /// - `Ok(None)` - For scalar functions or functions without `input_merge`
368    /// - `Err(ExpressionError)` - If the expression fails to compile
369    pub fn compile_input_merge(
370        self,
371        input: &super::expression::InputValue,
372    ) -> Result<
373        Option<super::expression::InputValue>,
374        super::expression::ExpressionError,
375    > {
376        let input_merge_expr = match self {
377            Function::Remote(RemoteFunction::Scalar { .. }) => None,
378            Function::Remote(RemoteFunction::Vector {
379                input_merge, ..
380            }) => Some(input_merge),
381            Function::Inline(InlineFunction::Scalar { .. }) => None,
382            Function::Inline(InlineFunction::Vector {
383                input_merge, ..
384            }) => input_merge,
385        };
386        match input_merge_expr {
387            Some(input_merge_expr) => {
388                // prepare params for compiling input_merge expression
389                let params = super::expression::Params::Ref(
390                    super::expression::ParamsRef {
391                        input,
392                        output: None,
393                        map: None,
394                        tasks_min: None,
395                        tasks_max: None,
396                        depth: None,
397                        name: None,
398                        spec: None,
399                    },
400                );
401                // compile input_merge
402                let input_merge = input_merge_expr.compile_one(&params)?;
403                Ok(Some(input_merge))
404            }
405            None => Ok(None),
406        }
407    }
408
409    /// Returns the function's description, if available.
410    pub fn description(&self) -> Option<&str> {
411        match self {
412            Function::Remote(remote_function) => {
413                Some(remote_function.description())
414            }
415            Function::Inline(_) => None,
416        }
417    }
418
419    /// Returns the function's input schema, if available.
420    pub fn input_schema(&self) -> Option<&super::expression::InputSchema> {
421        match self {
422            Function::Remote(remote_function) => {
423                Some(remote_function.input_schema())
424            }
425            Function::Inline(_) => None,
426        }
427    }
428
429    /// Returns the function's tasks.
430    pub fn tasks(&self) -> &[super::TaskExpression] {
431        match self {
432            Function::Remote(remote_function) => remote_function.tasks(),
433            Function::Inline(inline_function) => inline_function.tasks(),
434        }
435    }
436
437    /// Returns the function's expected output length expression, if defined.
438    pub fn output_length(&self) -> Option<&super::expression::Expression> {
439        match self {
440            Function::Remote(remote_function) => {
441                remote_function.output_length()
442            }
443            Function::Inline(_) => None,
444        }
445    }
446
447    /// Returns the function's input_split expression, if defined.
448    pub fn input_split(&self) -> Option<&super::expression::Expression> {
449        match self {
450            Function::Remote(remote_function) => remote_function.input_split(),
451            Function::Inline(inline_function) => inline_function.input_split(),
452        }
453    }
454
455    /// Returns the function's input_merge expression, if defined.
456    pub fn input_merge(&self) -> Option<&super::expression::Expression> {
457        match self {
458            Function::Remote(remote_function) => remote_function.input_merge(),
459            Function::Inline(inline_function) => inline_function.input_merge(),
460        }
461    }
462}
463
464/// A remote function with full metadata.
465///
466/// Remote functions are stored as `function.json` in repositories and
467/// referenced by `remote/owner/repository`. They include documentation fields
468/// that inline functions lack.
469#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
470#[serde(tag = "type")]
471#[schemars(rename = "functions.RemoteFunction")]
472pub enum RemoteFunction {
473    /// Produces a single score in [0, 1].
474    #[schemars(title = "Scalar")]
475    #[serde(rename = "scalar.function")]
476    Scalar {
477        /// Human-readable description of what the function does.
478        description: String,
479        /// JSON Schema defining the expected input structure.
480        input_schema: super::expression::InputSchema,
481        /// The list of tasks to execute. Tasks with a `map` expression are
482        /// expanded into multiple instances. Each instance is compiled with
483        /// `map` set to the current integer index.
484        /// Receives: `input`, `map` (if mapped).
485        tasks: Vec<super::TaskExpression>,
486    },
487    /// Produces a vector of scores that sums to 1.
488    #[schemars(title = "Vector")]
489    #[serde(rename = "vector.function")]
490    Vector {
491        /// Human-readable description of what the function does.
492        description: String,
493        /// JSON Schema defining the expected input structure.
494        input_schema: super::expression::InputSchema,
495        /// The list of tasks to execute. Tasks with a `map` expression are
496        /// expanded into multiple instances. Each instance is compiled with
497        /// `map` set to the current integer index.
498        /// Receives: `input`, `map` (if mapped).
499        tasks: Vec<super::TaskExpression>,
500        /// Expression computing the expected output vector length for task outputs.
501        /// Receives: `input`.
502        output_length: super::expression::Expression,
503        /// Expression transforming input into an input array of the output_length
504        /// When the Function is executed with any input from the array,
505        /// The output_length should be 1.
506        /// Receives: `input`.
507        input_split: super::expression::Expression,
508        /// Expression transforming an array of inputs computed by `input_split`
509        /// into a single Input object for the Function.
510        /// Receives: `input` (as an array).
511        input_merge: super::expression::Expression,
512    },
513}
514
515impl RemoteFunction {
516    /// Returns the function's description.
517    pub fn description(&self) -> &str {
518        match self {
519            RemoteFunction::Scalar { description, .. } => description,
520            RemoteFunction::Vector { description, .. } => description,
521        }
522    }
523
524    /// Returns the function's input schema.
525    pub fn input_schema(&self) -> &super::expression::InputSchema {
526        match self {
527            RemoteFunction::Scalar { input_schema, .. } => input_schema,
528            RemoteFunction::Vector { input_schema, .. } => input_schema,
529        }
530    }
531
532    /// Returns the function's tasks.
533    pub fn tasks(&self) -> &[super::TaskExpression] {
534        match self {
535            RemoteFunction::Scalar { tasks, .. } => tasks,
536            RemoteFunction::Vector { tasks, .. } => tasks,
537        }
538    }
539
540    /// Returns the function's expected output length, if defined (vector functions only).
541    pub fn output_length(&self) -> Option<&super::expression::Expression> {
542        match self {
543            RemoteFunction::Scalar { .. } => None,
544            RemoteFunction::Vector { output_length, .. } => Some(output_length),
545        }
546    }
547
548    /// Returns the function's input_split expression, if defined (vector functions only).
549    pub fn input_split(&self) -> Option<&super::expression::Expression> {
550        match self {
551            RemoteFunction::Scalar { .. } => None,
552            RemoteFunction::Vector { input_split, .. } => Some(input_split),
553        }
554    }
555
556    /// Returns the function's input_merge expression, if defined (vector functions only).
557    pub fn input_merge(&self) -> Option<&super::expression::Expression> {
558        match self {
559            RemoteFunction::Scalar { .. } => None,
560            RemoteFunction::Vector { input_merge, .. } => Some(input_merge),
561        }
562    }
563
564    pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
565        self.tasks().iter().filter_map(|task| match task {
566            super::TaskExpression::ScalarFunction(t) => Some(&t.path),
567            super::TaskExpression::VectorFunction(t) => Some(&t.path),
568            _ => None,
569        })
570    }
571}
572
573/// An inline function definition without metadata.
574///
575/// Used when embedding function logic directly in requests rather than
576/// referencing a remote function. Lacks description and input
577/// schema fields.
578#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
579#[serde(tag = "type")]
580#[schemars(rename = "functions.InlineFunction")]
581pub enum InlineFunction {
582    /// Produces a single score in [0, 1].
583    #[schemars(title = "Scalar")]
584    #[serde(rename = "scalar.function")]
585    Scalar {
586        /// The list of tasks to execute. Tasks with a `map` expression are
587        /// expanded into multiple instances. Each instance is compiled with
588        /// `map` set to the current integer index.
589        /// Receives: `input`, `map` (if mapped).
590        tasks: Vec<super::TaskExpression>,
591    },
592    /// Produces a vector of scores that sums to 1.
593    #[schemars(title = "Vector")]
594    #[serde(rename = "vector.function")]
595    Vector {
596        /// The list of tasks to execute. Tasks with a `map` expression are
597        /// expanded into multiple instances. Each instance is compiled with
598        /// `map` set to the current integer index.
599        /// Receives: `input`, `map` (if mapped).
600        tasks: Vec<super::TaskExpression>,
601        /// Expression transforming input into an input array of the output_length
602        /// When the Function is executed with any input from the array,
603        /// The output_length should be 1.
604        /// Receives: `input`.
605        /// Only required if the request uses a strategy that needs input splitting.
606        input_split: Option<super::expression::Expression>,
607        /// Expression transforming an array of inputs computed by `input_split`
608        /// into a single Input object for the Function.
609        /// Receives: `input` (as an array).
610        /// Only required if the request uses a strategy that needs input splitting.
611        input_merge: Option<super::expression::Expression>,
612    },
613}
614
615impl InlineFunction {
616    /// Returns the function's tasks.
617    pub fn tasks(&self) -> &[super::TaskExpression] {
618        match self {
619            InlineFunction::Scalar { tasks, .. } => tasks,
620            InlineFunction::Vector { tasks, .. } => tasks,
621        }
622    }
623
624    /// Returns the function's input_split expression, if defined (vector functions only).
625    pub fn input_split(&self) -> Option<&super::expression::Expression> {
626        match self {
627            InlineFunction::Scalar { .. } => None,
628            InlineFunction::Vector { input_split, .. } => input_split.as_ref(),
629        }
630    }
631
632    /// Returns the function's input_merge expression, if defined (vector functions only).
633    pub fn input_merge(&self) -> Option<&super::expression::Expression> {
634        match self {
635            InlineFunction::Scalar { .. } => None,
636            InlineFunction::Vector { input_merge, .. } => input_merge.as_ref(),
637        }
638    }
639
640    pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
641        self.tasks().iter().filter_map(|task| match task {
642            super::TaskExpression::ScalarFunction(t) => Some(&t.path),
643            super::TaskExpression::VectorFunction(t) => Some(&t.path),
644            _ => None,
645        })
646    }
647}
648
649#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
650#[schemars(rename = "functions.FunctionType")]
651pub enum FunctionType {
652    #[schemars(title = "Scalar")]
653    #[serde(rename = "scalar.function")]
654    Scalar,
655    #[schemars(title = "Vector")]
656    #[serde(rename = "vector.function")]
657    Vector,
658}