objectiveai_sdk/functions/function.rs
1//! Function types and client-side compilation.
2//!
3//! # Output Computation
4//!
5//! Functions do **not** have a top-level output expression. Instead, each task has its
6//! own `output` expression that transforms its raw result into a [`TaskOutputOwned`].
7//! The function's final output is computed as a **weighted average** of all task outputs
8//! using profile weights.
9//!
10//! - If a function has only 1 task, that task's output becomes the function's output directly
11//! - If a function has multiple tasks, each task's output is weighted and averaged
12//!
13//! Each task's `output` expression must return a valid `TaskOutputOwned` for the function's type:
14//! - **Scalar functions**: each task must return `Scalar(value)` where value is in [0, 1]
15//! - **Vector functions**: each task must return `Vector(values)` where values sum to ~1
16//!
17//! [`TaskOutputOwned`]: super::expression::TaskOutputOwned
18
19use serde::{Deserialize, Serialize};
20use schemars::JsonSchema;
21
22/// A Function definition, either remote or inline.
23///
24/// Functions are composable scoring pipelines that transform structured input
25/// into scores. Each task has an `output` expression that transforms its raw result
26/// into a `TaskOutputOwned`. The function's final output is the weighted average of
27/// all task outputs using profile weights.
28///
29/// Use [`compile_tasks`](Self::compile_tasks) to preview how task expressions resolve
30/// for given inputs.
31#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
32#[serde(untagged)]
33#[schemars(rename = "functions.Function")]
34pub enum Function {
35 /// A remote function with metadata (description, schema, etc.).
36 #[schemars(title = "Remote")]
37 Remote(RemoteFunction),
38 /// An inline function definition without metadata.
39 #[schemars(title = "Inline")]
40 Inline(InlineFunction),
41}
42
43impl Function {
44 /// Validates the input against the function's input schema.
45 ///
46 /// For remote functions, checks whether the provided input conforms to
47 /// the function's JSON Schema definition. For inline functions, returns
48 /// `None` since they lack schema definitions.
49 ///
50 /// # Returns
51 ///
52 /// - `Some(true)` if the input is valid against the schema
53 /// - `Some(false)` if the input is invalid
54 /// - `None` for inline functions (no schema to validate against)
55 pub fn validate_input(
56 &self,
57 input: &super::expression::InputValue,
58 ) -> Option<bool> {
59 match self {
60 Function::Remote(remote_function) => {
61 Some(remote_function.input_schema().validate_input(input))
62 }
63 Function::Inline(_) => None,
64 }
65 }
66
67 /// Compiles task expressions to show the final tasks for a given input.
68 ///
69 /// Evaluates all expressions (JMESPath or Starlark) in the function's tasks
70 /// using the provided input data. Tasks with `skip` expressions that evaluate
71 /// to true return `None`. Tasks with `map` fields produce multiple task instances.
72 ///
73 /// # Returns
74 ///
75 /// A vector where each element corresponds to a task definition:
76 /// - `None` if the task was skipped
77 /// - `Some(CompiledTask::One(...))` for non-mapped tasks
78 /// - `Some(CompiledTask::Many(...))` for mapped tasks
79 pub fn compile_tasks(
80 self,
81 input: &super::expression::InputValue,
82 ) -> Result<
83 Vec<Option<super::CompiledTask>>,
84 super::expression::ExpressionError,
85 > {
86 // extract task expressions
87 let task_exprs = match self {
88 Function::Remote(RemoteFunction::Scalar { tasks, .. }) => tasks,
89 Function::Remote(RemoteFunction::Vector { tasks, .. }) => tasks,
90 Function::Inline(InlineFunction::Scalar { tasks, .. }) => tasks,
91 Function::Inline(InlineFunction::Vector { tasks, .. }) => tasks,
92 };
93
94 // prepare params for compiling expressions
95 let mut params =
96 super::expression::Params::Ref(super::expression::ParamsRef {
97 input,
98 output: None,
99 map: None,
100 tasks_min: None,
101 tasks_max: None,
102 depth: None,
103 name: None,
104 spec: None,
105 });
106
107 // compile tasks
108 let mut tasks = Vec::with_capacity(task_exprs.len());
109 for mut task_expr in task_exprs {
110 tasks.push(
111 if let Some(skip_expr) = task_expr.take_skip()
112 && skip_expr.compile_one::<bool>(¶ms)?
113 {
114 // None if task is skipped
115 None
116 } else if let Some(map_expr) = task_expr.map() {
117 // evaluate map expression to get count
118 let count: u64 = map_expr.compile_one(¶ms)?;
119 // compile task for each map index
120 let mut map_tasks = Vec::with_capacity(count as usize);
121 for i in 0..count {
122 // set map index
123 match &mut params {
124 super::expression::Params::Ref(params_ref) => {
125 params_ref.map = Some(i);
126 }
127 _ => unreachable!(),
128 }
129 // compile task with map index
130 map_tasks.push(task_expr.clone().compile(¶ms)?);
131 // reset map index
132 match &mut params {
133 super::expression::Params::Ref(params_ref) => {
134 params_ref.map = None;
135 }
136 _ => unreachable!(),
137 }
138 }
139 Some(super::CompiledTask::Many(map_tasks))
140 } else {
141 // compile single task
142 Some(super::CompiledTask::One(task_expr.compile(¶ms)?))
143 },
144 );
145 }
146
147 // compiled tasks
148 Ok(tasks)
149 }
150
151 // /// Computes the final output given input and task outputs.
152 // ///
153 // /// Evaluates the function's output expression using the provided input data
154 // /// and task results. Also validates that the output meets constraints:
155 // /// - Scalar functions: output must be in [0, 1]
156 // /// - Vector functions: output must sum to approximately 1
157 // pub fn compile_output(
158 // self,
159 // input: &super::expression::InputValue,
160 // task_outputs: &[Option<super::expression::TaskOutput>],
161 // ) -> Result<
162 // super::expression::CompiledFunctionOutput,
163 // super::expression::ExpressionError,
164 // > {
165 // #[derive(Clone, Copy)]
166 // enum FunctionType {
167 // Scalar,
168 // Vector,
169 // }
170
171 // // prepare params for compiling output_length expression
172 // let mut params =
173 // super::expression::Params::Ref(super::expression::ParamsRef {
174 // input,
175 // output: None,
176 // map: None,
177 // });
178
179 // // extract output expression and output_length
180 // let (function_type, output_expr, output_length) = match self {
181 // Function::Remote(RemoteFunction::Scalar { output, .. }) => {
182 // (FunctionType::Scalar, output, None)
183 // }
184 // Function::Remote(RemoteFunction::Vector {
185 // output,
186 // output_length,
187 // ..
188 // }) => (
189 // FunctionType::Vector,
190 // output,
    //             Some(output_length.compile_one(&params)?),
192 // ),
193 // Function::Inline(InlineFunction::Scalar { output, .. }) => {
194 // (FunctionType::Scalar, output, None)
195 // }
196 // Function::Inline(InlineFunction::Vector { output, .. }) => {
197 // (FunctionType::Vector, output, None)
198 // }
199 // };
200
201 // // prepare params for compiling output expression
202 // match &mut params {
203 // super::expression::Params::Ref(params_ref) => {
204 // params_ref.tasks = task_outputs;
205 // }
206 // _ => unreachable!(),
207 // }
208
209 // // compile output
210 // let output = output_expr
    //         .compile_one::<super::expression::FunctionOutput>(&params)?;
212
213 // // validate output
214 // let valid = match (function_type, &output, output_length) {
215 // (
216 // FunctionType::Scalar,
217 // &super::expression::FunctionOutput::Scalar(scalar),
218 // _,
219 // ) => {
220 // scalar >= rust_decimal::Decimal::ZERO
221 // && scalar <= rust_decimal::Decimal::ONE
222 // }
223 // (
224 // FunctionType::Vector,
225 // super::expression::FunctionOutput::Vector(vector),
226 // Some(length),
227 // ) => {
228 // let sum = vector.iter().sum::<rust_decimal::Decimal>();
229 // vector.len() == length as usize
230 // && sum >= rust_decimal::dec!(0.99)
231 // && sum <= rust_decimal::dec!(1.01)
232 // }
233 // (
234 // FunctionType::Vector,
235 // super::expression::FunctionOutput::Vector(vector),
236 // None,
237 // ) => {
238 // let sum = vector.iter().sum::<rust_decimal::Decimal>();
239 // sum >= rust_decimal::dec!(0.99)
240 // && sum <= rust_decimal::dec!(1.01)
241 // }
242 // _ => false,
243 // };
244
245 // // compiled output
246 // Ok(super::expression::CompiledFunctionOutput { output, valid })
247 // }
248
249 /// Computes the expected output length for a vector function.
250 ///
251 /// Evaluates the `output_length` expression to determine how many elements
252 /// the output vector should contain. This is only applicable to remote
253 /// vector functions which have an `output_length` field.
254 ///
255 /// # Arguments
256 ///
257 /// * `input` - The function input used to compute the output length
258 ///
259 /// # Returns
260 ///
261 /// - `Ok(Some(u64))` - The expected output length for remote vector functions
262 /// - `Ok(None)` - For scalar functions or inline functions
263 /// - `Err(ExpressionError)` - If the expression fails to compile
264 pub fn compile_output_length(
265 self,
266 input: &super::expression::InputValue,
267 ) -> Result<Option<u64>, super::expression::ExpressionError> {
268 let output_length_expr = match self {
269 Function::Remote(RemoteFunction::Scalar { .. }) => None,
270 Function::Remote(RemoteFunction::Vector {
271 output_length, ..
272 }) => Some(output_length),
273 Function::Inline(InlineFunction::Scalar { .. }) => None,
274 Function::Inline(InlineFunction::Vector { .. }) => None,
275 };
276 match output_length_expr {
277 Some(output_length_expr) => {
278 // prepare params for compiling output_length expression
279 let params = super::expression::Params::Ref(
280 super::expression::ParamsRef {
281 input,
282 output: None,
283 map: None,
284 tasks_min: None,
285 tasks_max: None,
286 depth: None,
287 name: None,
288 spec: None,
289 },
290 );
291 // compile output_length
292 let output_length = output_length_expr.compile_one(¶ms)?;
293 Ok(Some(output_length))
294 }
295 None => Ok(None),
296 }
297 }
298
299 /// Compiles the `input_split` expression to split input into multiple sub-inputs.
300 ///
301 /// Used by strategies like Swiss System that need to partition input into
302 /// smaller pools. The expression transforms the original input into an array
303 /// of inputs, where each element can be processed independently.
304 ///
305 /// # Arguments
306 ///
307 /// * `input` - The original function input to split
308 ///
309 /// # Returns
310 ///
311 /// - `Ok(Some(Vec<Input>))` - The split inputs for vector functions with `input_split` defined
312 /// - `Ok(None)` - For scalar functions or functions without `input_split`
313 /// - `Err(ExpressionError)` - If the expression fails to compile
314 pub fn compile_input_split(
315 self,
316 input: &super::expression::InputValue,
317 ) -> Result<
318 Option<Vec<super::expression::InputValue>>,
319 super::expression::ExpressionError,
320 > {
321 let input_split_expr = match self {
322 Function::Remote(RemoteFunction::Scalar { .. }) => None,
323 Function::Remote(RemoteFunction::Vector {
324 input_split, ..
325 }) => Some(input_split),
326 Function::Inline(InlineFunction::Scalar { .. }) => None,
327 Function::Inline(InlineFunction::Vector {
328 input_split, ..
329 }) => input_split,
330 };
331 match input_split_expr {
332 Some(input_split_expr) => {
333 // prepare params for compiling input_split expression
334 let params = super::expression::Params::Ref(
335 super::expression::ParamsRef {
336 input,
337 output: None,
338 map: None,
339 tasks_min: None,
340 tasks_max: None,
341 depth: None,
342 name: None,
343 spec: None,
344 },
345 );
346 // compile input_split
347 let input_split = input_split_expr.compile_one(¶ms)?;
348 Ok(Some(input_split))
349 }
350 None => Ok(None),
351 }
352 }
353
354 /// Compiles the `input_merge` expression to merge multiple sub-inputs back into one.
355 ///
356 /// Used by strategies like Swiss System to recombine a subset of split inputs
357 /// into a single input for pool execution. The expression transforms an array
358 /// of inputs (a subset from `input_split`) into a single merged input.
359 ///
360 /// # Arguments
361 ///
362 /// * `input` - An array of inputs to merge (typically a subset from `compile_input_split`)
363 ///
364 /// # Returns
365 ///
366 /// - `Ok(Some(Input))` - The merged input for vector functions with `input_merge` defined
367 /// - `Ok(None)` - For scalar functions or functions without `input_merge`
368 /// - `Err(ExpressionError)` - If the expression fails to compile
369 pub fn compile_input_merge(
370 self,
371 input: &super::expression::InputValue,
372 ) -> Result<
373 Option<super::expression::InputValue>,
374 super::expression::ExpressionError,
375 > {
376 let input_merge_expr = match self {
377 Function::Remote(RemoteFunction::Scalar { .. }) => None,
378 Function::Remote(RemoteFunction::Vector {
379 input_merge, ..
380 }) => Some(input_merge),
381 Function::Inline(InlineFunction::Scalar { .. }) => None,
382 Function::Inline(InlineFunction::Vector {
383 input_merge, ..
384 }) => input_merge,
385 };
386 match input_merge_expr {
387 Some(input_merge_expr) => {
388 // prepare params for compiling input_merge expression
389 let params = super::expression::Params::Ref(
390 super::expression::ParamsRef {
391 input,
392 output: None,
393 map: None,
394 tasks_min: None,
395 tasks_max: None,
396 depth: None,
397 name: None,
398 spec: None,
399 },
400 );
401 // compile input_merge
402 let input_merge = input_merge_expr.compile_one(¶ms)?;
403 Ok(Some(input_merge))
404 }
405 None => Ok(None),
406 }
407 }
408
409 /// Returns the function's description, if available.
410 pub fn description(&self) -> Option<&str> {
411 match self {
412 Function::Remote(remote_function) => {
413 Some(remote_function.description())
414 }
415 Function::Inline(_) => None,
416 }
417 }
418
419 /// Returns the function's input schema, if available.
420 pub fn input_schema(&self) -> Option<&super::expression::InputSchema> {
421 match self {
422 Function::Remote(remote_function) => {
423 Some(remote_function.input_schema())
424 }
425 Function::Inline(_) => None,
426 }
427 }
428
429 /// Returns the function's tasks.
430 pub fn tasks(&self) -> &[super::TaskExpression] {
431 match self {
432 Function::Remote(remote_function) => remote_function.tasks(),
433 Function::Inline(inline_function) => inline_function.tasks(),
434 }
435 }
436
437 /// Returns the function's expected output length expression, if defined.
438 pub fn output_length(&self) -> Option<&super::expression::Expression> {
439 match self {
440 Function::Remote(remote_function) => {
441 remote_function.output_length()
442 }
443 Function::Inline(_) => None,
444 }
445 }
446
447 /// Returns the function's input_split expression, if defined.
448 pub fn input_split(&self) -> Option<&super::expression::Expression> {
449 match self {
450 Function::Remote(remote_function) => remote_function.input_split(),
451 Function::Inline(inline_function) => inline_function.input_split(),
452 }
453 }
454
455 /// Returns the function's input_merge expression, if defined.
456 pub fn input_merge(&self) -> Option<&super::expression::Expression> {
457 match self {
458 Function::Remote(remote_function) => remote_function.input_merge(),
459 Function::Inline(inline_function) => inline_function.input_merge(),
460 }
461 }
462}
463
464/// A remote function with full metadata.
465///
466/// Remote functions are stored as `function.json` in repositories and
467/// referenced by `remote/owner/repository`. They include documentation fields
468/// that inline functions lack.
469#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
470#[serde(tag = "type")]
471#[schemars(rename = "functions.RemoteFunction")]
472pub enum RemoteFunction {
473 /// Produces a single score in [0, 1].
474 #[schemars(title = "Scalar")]
475 #[serde(rename = "scalar.function")]
476 Scalar {
477 /// Human-readable description of what the function does.
478 description: String,
479 /// JSON Schema defining the expected input structure.
480 input_schema: super::expression::InputSchema,
481 /// The list of tasks to execute. Tasks with a `map` expression are
482 /// expanded into multiple instances. Each instance is compiled with
483 /// `map` set to the current integer index.
484 /// Receives: `input`, `map` (if mapped).
485 tasks: Vec<super::TaskExpression>,
486 },
487 /// Produces a vector of scores that sums to 1.
488 #[schemars(title = "Vector")]
489 #[serde(rename = "vector.function")]
490 Vector {
491 /// Human-readable description of what the function does.
492 description: String,
493 /// JSON Schema defining the expected input structure.
494 input_schema: super::expression::InputSchema,
495 /// The list of tasks to execute. Tasks with a `map` expression are
496 /// expanded into multiple instances. Each instance is compiled with
497 /// `map` set to the current integer index.
498 /// Receives: `input`, `map` (if mapped).
499 tasks: Vec<super::TaskExpression>,
500 /// Expression computing the expected output vector length for task outputs.
501 /// Receives: `input`.
502 output_length: super::expression::Expression,
503 /// Expression transforming input into an input array of the output_length
504 /// When the Function is executed with any input from the array,
505 /// The output_length should be 1.
506 /// Receives: `input`.
507 input_split: super::expression::Expression,
508 /// Expression transforming an array of inputs computed by `input_split`
509 /// into a single Input object for the Function.
510 /// Receives: `input` (as an array).
511 input_merge: super::expression::Expression,
512 },
513}
514
515impl RemoteFunction {
516 /// Returns the function's description.
517 pub fn description(&self) -> &str {
518 match self {
519 RemoteFunction::Scalar { description, .. } => description,
520 RemoteFunction::Vector { description, .. } => description,
521 }
522 }
523
524 /// Returns the function's input schema.
525 pub fn input_schema(&self) -> &super::expression::InputSchema {
526 match self {
527 RemoteFunction::Scalar { input_schema, .. } => input_schema,
528 RemoteFunction::Vector { input_schema, .. } => input_schema,
529 }
530 }
531
532 /// Returns the function's tasks.
533 pub fn tasks(&self) -> &[super::TaskExpression] {
534 match self {
535 RemoteFunction::Scalar { tasks, .. } => tasks,
536 RemoteFunction::Vector { tasks, .. } => tasks,
537 }
538 }
539
540 /// Returns the function's expected output length, if defined (vector functions only).
541 pub fn output_length(&self) -> Option<&super::expression::Expression> {
542 match self {
543 RemoteFunction::Scalar { .. } => None,
544 RemoteFunction::Vector { output_length, .. } => Some(output_length),
545 }
546 }
547
548 /// Returns the function's input_split expression, if defined (vector functions only).
549 pub fn input_split(&self) -> Option<&super::expression::Expression> {
550 match self {
551 RemoteFunction::Scalar { .. } => None,
552 RemoteFunction::Vector { input_split, .. } => Some(input_split),
553 }
554 }
555
556 /// Returns the function's input_merge expression, if defined (vector functions only).
557 pub fn input_merge(&self) -> Option<&super::expression::Expression> {
558 match self {
559 RemoteFunction::Scalar { .. } => None,
560 RemoteFunction::Vector { input_merge, .. } => Some(input_merge),
561 }
562 }
563
564 pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
565 self.tasks().iter().filter_map(|task| match task {
566 super::TaskExpression::ScalarFunction(t) => Some(&t.path),
567 super::TaskExpression::VectorFunction(t) => Some(&t.path),
568 _ => None,
569 })
570 }
571}
572
573/// An inline function definition without metadata.
574///
575/// Used when embedding function logic directly in requests rather than
576/// referencing a remote function. Lacks description and input
577/// schema fields.
578#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
579#[serde(tag = "type")]
580#[schemars(rename = "functions.InlineFunction")]
581pub enum InlineFunction {
582 /// Produces a single score in [0, 1].
583 #[schemars(title = "Scalar")]
584 #[serde(rename = "scalar.function")]
585 Scalar {
586 /// The list of tasks to execute. Tasks with a `map` expression are
587 /// expanded into multiple instances. Each instance is compiled with
588 /// `map` set to the current integer index.
589 /// Receives: `input`, `map` (if mapped).
590 tasks: Vec<super::TaskExpression>,
591 },
592 /// Produces a vector of scores that sums to 1.
593 #[schemars(title = "Vector")]
594 #[serde(rename = "vector.function")]
595 Vector {
596 /// The list of tasks to execute. Tasks with a `map` expression are
597 /// expanded into multiple instances. Each instance is compiled with
598 /// `map` set to the current integer index.
599 /// Receives: `input`, `map` (if mapped).
600 tasks: Vec<super::TaskExpression>,
601 /// Expression transforming input into an input array of the output_length
602 /// When the Function is executed with any input from the array,
603 /// The output_length should be 1.
604 /// Receives: `input`.
605 /// Only required if the request uses a strategy that needs input splitting.
606 input_split: Option<super::expression::Expression>,
607 /// Expression transforming an array of inputs computed by `input_split`
608 /// into a single Input object for the Function.
609 /// Receives: `input` (as an array).
610 /// Only required if the request uses a strategy that needs input splitting.
611 input_merge: Option<super::expression::Expression>,
612 },
613}
614
615impl InlineFunction {
616 /// Returns the function's tasks.
617 pub fn tasks(&self) -> &[super::TaskExpression] {
618 match self {
619 InlineFunction::Scalar { tasks, .. } => tasks,
620 InlineFunction::Vector { tasks, .. } => tasks,
621 }
622 }
623
624 /// Returns the function's input_split expression, if defined (vector functions only).
625 pub fn input_split(&self) -> Option<&super::expression::Expression> {
626 match self {
627 InlineFunction::Scalar { .. } => None,
628 InlineFunction::Vector { input_split, .. } => input_split.as_ref(),
629 }
630 }
631
632 /// Returns the function's input_merge expression, if defined (vector functions only).
633 pub fn input_merge(&self) -> Option<&super::expression::Expression> {
634 match self {
635 InlineFunction::Scalar { .. } => None,
636 InlineFunction::Vector { input_merge, .. } => input_merge.as_ref(),
637 }
638 }
639
640 pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
641 self.tasks().iter().filter_map(|task| match task {
642 super::TaskExpression::ScalarFunction(t) => Some(&t.path),
643 super::TaskExpression::VectorFunction(t) => Some(&t.path),
644 _ => None,
645 })
646 }
647}
648
649#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
650#[schemars(rename = "functions.FunctionType")]
651pub enum FunctionType {
652 #[schemars(title = "Scalar")]
653 #[serde(rename = "scalar.function")]
654 Scalar,
655 #[schemars(title = "Vector")]
656 #[serde(rename = "vector.function")]
657 Vector,
658}