objectiveai_sdk/functions/function.rs
//! Function types and client-side compilation.
//!
//! # Output Computation
//!
//! Functions do **not** have a top-level output expression. Instead, each task has its
//! own `output` expression that transforms its raw result into a [`TaskOutputOwned`].
//! The function's final output is computed as a **weighted average** of all task outputs
//! using profile weights.
//!
//! - If a function has only 1 task, that task's output becomes the function's output directly.
//! - If a function has multiple tasks, each task's output is weighted and averaged.
//!
//! Each task's `output` expression must return a valid `TaskOutputOwned` for the function's type:
//! - **Scalar functions**: each task must return `Scalar(value)` where `value` is in [0, 1].
//! - **Vector functions**: each task must return `Vector(values)` where `values` sum to approximately 1.
//!
//! [`TaskOutputOwned`]: super::expression::TaskOutputOwned
18
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21
22/// A Function definition, either remote or inline.
23///
24/// Functions are composable scoring pipelines that transform structured input
25/// into scores. Each task has an `output` expression that transforms its raw result
26/// into a `TaskOutputOwned`. The function's final output is the weighted average of
27/// all task outputs using profile weights.
28///
29/// Use [`compile_tasks`](Self::compile_tasks) to preview how task expressions resolve
30/// for given inputs.
31#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
32#[serde(untagged)]
33#[schemars(rename = "functions.Function")]
34pub enum Function {
35 /// A remote function with metadata (description, schema, etc.).
36 #[schemars(title = "Remote")]
37 Remote(RemoteFunction),
38 /// An inline function definition without metadata.
39 #[schemars(title = "Inline")]
40 Inline(InlineFunction),
41}
42
43impl Function {
44 /// Validates the input against the function's input schema.
45 ///
46 /// For remote functions, checks whether the provided input conforms to
47 /// the function's JSON Schema definition. For inline functions, returns
48 /// `None` since they lack schema definitions.
49 ///
50 /// # Returns
51 ///
52 /// - `Some(true)` if the input is valid against the schema
53 /// - `Some(false)` if the input is invalid
54 /// - `None` for inline functions (no schema to validate against)
55 pub fn validate_input(
56 &self,
57 input: &super::expression::InputValue,
58 ) -> Option<bool> {
59 match self {
60 Function::Remote(remote_function) => {
61 Some(remote_function.input_schema().validate_input(input))
62 }
63 Function::Inline(_) => None,
64 }
65 }
66
67 /// Compiles task expressions to show the final tasks for a given input.
68 ///
69 /// Evaluates all expressions (JMESPath or Starlark) in the function's tasks
70 /// using the provided input data. Tasks with `skip` expressions that evaluate
71 /// to true return `None`. Tasks with `map` fields produce multiple task instances.
72 ///
73 /// # Returns
74 ///
75 /// A vector where each element corresponds to a task definition:
76 /// - `None` if the task was skipped
77 /// - `Some(CompiledTask::One(...))` for non-mapped tasks
78 /// - `Some(CompiledTask::Many(...))` for mapped tasks
79 pub fn compile_tasks(
80 self,
81 input: &super::expression::InputValue,
82 ) -> Result<
83 Vec<Option<super::CompiledTask>>,
84 super::expression::ExpressionError,
85 > {
86 // extract task expressions
87 let task_exprs = match self {
88 Function::Remote(RemoteFunction::Scalar { tasks, .. }) => tasks,
89 Function::Remote(RemoteFunction::Vector { tasks, .. }) => tasks,
90 Function::Inline(InlineFunction::Scalar { tasks, .. }) => tasks,
91 Function::Inline(InlineFunction::Vector { tasks, .. }) => tasks,
92 };
93
94 // prepare params for compiling expressions
95 let mut params =
96 super::expression::Params::Ref(super::expression::ParamsRef {
97 input,
98 output: None,
99 map: None,
100 });
101
102 // compile tasks
103 let mut tasks = Vec::with_capacity(task_exprs.len());
104 for mut task_expr in task_exprs {
105 tasks.push(
106 if let Some(skip_expr) = task_expr.take_skip()
107 && skip_expr.compile_one::<bool>(¶ms)?
108 {
109 // None if task is skipped
110 None
111 } else if let Some(map_expr) = task_expr.map() {
112 // evaluate map expression to get count
113 let count: u64 = map_expr.compile_one(¶ms)?;
114 // compile task for each map index
115 let mut map_tasks = Vec::with_capacity(count as usize);
116 for i in 0..count {
117 // set map index
118 match &mut params {
119 super::expression::Params::Ref(params_ref) => {
120 params_ref.map = Some(i);
121 }
122 _ => unreachable!(),
123 }
124 // compile task with map index
125 map_tasks.push(task_expr.clone().compile(¶ms)?);
126 // reset map index
127 match &mut params {
128 super::expression::Params::Ref(params_ref) => {
129 params_ref.map = None;
130 }
131 _ => unreachable!(),
132 }
133 }
134 Some(super::CompiledTask::Many(map_tasks))
135 } else {
136 // compile single task
137 Some(super::CompiledTask::One(task_expr.compile(¶ms)?))
138 },
139 );
140 }
141
142 // compiled tasks
143 Ok(tasks)
144 }
145
146 // /// Computes the final output given input and task outputs.
147 // ///
148 // /// Evaluates the function's output expression using the provided input data
149 // /// and task results. Also validates that the output meets constraints:
150 // /// - Scalar functions: output must be in [0, 1]
151 // /// - Vector functions: output must sum to approximately 1
152 // pub fn compile_output(
153 // self,
154 // input: &super::expression::InputValue,
155 // task_outputs: &[Option<super::expression::TaskOutput>],
156 // ) -> Result<
157 // super::expression::CompiledFunctionOutput,
158 // super::expression::ExpressionError,
159 // > {
160 // #[derive(Clone, Copy)]
161 // enum FunctionType {
162 // Scalar,
163 // Vector,
164 // }
165
166 // // prepare params for compiling output_length expression
167 // let mut params =
168 // super::expression::Params::Ref(super::expression::ParamsRef {
169 // input,
170 // output: None,
171 // map: None,
172 // });
173
174 // // extract output expression and output_length
175 // let (function_type, output_expr, output_length) = match self {
176 // Function::Remote(RemoteFunction::Scalar { output, .. }) => {
177 // (FunctionType::Scalar, output, None)
178 // }
179 // Function::Remote(RemoteFunction::Vector {
180 // output,
181 // output_length,
182 // ..
183 // }) => (
184 // FunctionType::Vector,
185 // output,
186 // Some(output_length.compile_one(¶ms)?),
187 // ),
188 // Function::Inline(InlineFunction::Scalar { output, .. }) => {
189 // (FunctionType::Scalar, output, None)
190 // }
191 // Function::Inline(InlineFunction::Vector { output, .. }) => {
192 // (FunctionType::Vector, output, None)
193 // }
194 // };
195
196 // // prepare params for compiling output expression
197 // match &mut params {
198 // super::expression::Params::Ref(params_ref) => {
199 // params_ref.tasks = task_outputs;
200 // }
201 // _ => unreachable!(),
202 // }
203
204 // // compile output
205 // let output = output_expr
206 // .compile_one::<super::expression::FunctionOutput>(¶ms)?;
207
208 // // validate output
209 // let valid = match (function_type, &output, output_length) {
210 // (
211 // FunctionType::Scalar,
212 // &super::expression::FunctionOutput::Scalar(scalar),
213 // _,
214 // ) => {
215 // scalar >= rust_decimal::Decimal::ZERO
216 // && scalar <= rust_decimal::Decimal::ONE
217 // }
218 // (
219 // FunctionType::Vector,
220 // super::expression::FunctionOutput::Vector(vector),
221 // Some(length),
222 // ) => {
223 // let sum = vector.iter().sum::<rust_decimal::Decimal>();
224 // vector.len() == length as usize
225 // && sum >= rust_decimal::dec!(0.99)
226 // && sum <= rust_decimal::dec!(1.01)
227 // }
228 // (
229 // FunctionType::Vector,
230 // super::expression::FunctionOutput::Vector(vector),
231 // None,
232 // ) => {
233 // let sum = vector.iter().sum::<rust_decimal::Decimal>();
234 // sum >= rust_decimal::dec!(0.99)
235 // && sum <= rust_decimal::dec!(1.01)
236 // }
237 // _ => false,
238 // };
239
240 // // compiled output
241 // Ok(super::expression::CompiledFunctionOutput { output, valid })
242 // }
243
244 /// Computes the expected output length for a vector function.
245 ///
246 /// Evaluates the `output_length` expression to determine how many elements
247 /// the output vector should contain. This is only applicable to remote
248 /// vector functions which have an `output_length` field.
249 ///
250 /// # Arguments
251 ///
252 /// * `input` - The function input used to compute the output length
253 ///
254 /// # Returns
255 ///
256 /// - `Ok(Some(u64))` - The expected output length for remote vector functions
257 /// - `Ok(None)` - For scalar functions or inline functions
258 /// - `Err(ExpressionError)` - If the expression fails to compile
259 pub fn compile_output_length(
260 self,
261 input: &super::expression::InputValue,
262 ) -> Result<Option<u64>, super::expression::ExpressionError> {
263 let output_length_expr = match self {
264 Function::Remote(RemoteFunction::Scalar { .. }) => None,
265 Function::Remote(RemoteFunction::Vector {
266 output_length, ..
267 }) => Some(output_length),
268 Function::Inline(InlineFunction::Scalar { .. }) => None,
269 Function::Inline(InlineFunction::Vector { .. }) => None,
270 };
271 match output_length_expr {
272 Some(output_length_expr) => {
273 // prepare params for compiling output_length expression
274 let params = super::expression::Params::Ref(
275 super::expression::ParamsRef {
276 input,
277 output: None,
278 map: None,
279 },
280 );
281 // compile output_length
282 let output_length = output_length_expr.compile_one(¶ms)?;
283 Ok(Some(output_length))
284 }
285 None => Ok(None),
286 }
287 }
288
289 /// Compiles the `input_split` expression to split input into multiple sub-inputs.
290 ///
291 /// Used by strategies like Swiss System that need to partition input into
292 /// smaller pools. The expression transforms the original input into an array
293 /// of inputs, where each element can be processed independently.
294 ///
295 /// # Arguments
296 ///
297 /// * `input` - The original function input to split
298 ///
299 /// # Returns
300 ///
301 /// - `Ok(Some(Vec<Input>))` - The split inputs for vector functions with `input_split` defined
302 /// - `Ok(None)` - For scalar functions or functions without `input_split`
303 /// - `Err(ExpressionError)` - If the expression fails to compile
304 pub fn compile_input_split(
305 self,
306 input: &super::expression::InputValue,
307 ) -> Result<
308 Option<Vec<super::expression::InputValue>>,
309 super::expression::ExpressionError,
310 > {
311 let input_split_expr = match self {
312 Function::Remote(RemoteFunction::Scalar { .. }) => None,
313 Function::Remote(RemoteFunction::Vector {
314 input_split, ..
315 }) => Some(input_split),
316 Function::Inline(InlineFunction::Scalar { .. }) => None,
317 Function::Inline(InlineFunction::Vector {
318 input_split, ..
319 }) => input_split,
320 };
321 match input_split_expr {
322 Some(input_split_expr) => {
323 // prepare params for compiling input_split expression
324 let params = super::expression::Params::Ref(
325 super::expression::ParamsRef {
326 input,
327 output: None,
328 map: None,
329 },
330 );
331 // compile input_split
332 let input_split = input_split_expr.compile_one(¶ms)?;
333 Ok(Some(input_split))
334 }
335 None => Ok(None),
336 }
337 }
338
339 /// Compiles the `input_merge` expression to merge multiple sub-inputs back into one.
340 ///
341 /// Used by strategies like Swiss System to recombine a subset of split inputs
342 /// into a single input for pool execution. The expression transforms an array
343 /// of inputs (a subset from `input_split`) into a single merged input.
344 ///
345 /// # Arguments
346 ///
347 /// * `input` - An array of inputs to merge (typically a subset from `compile_input_split`)
348 ///
349 /// # Returns
350 ///
351 /// - `Ok(Some(Input))` - The merged input for vector functions with `input_merge` defined
352 /// - `Ok(None)` - For scalar functions or functions without `input_merge`
353 /// - `Err(ExpressionError)` - If the expression fails to compile
354 pub fn compile_input_merge(
355 self,
356 input: &super::expression::InputValue,
357 ) -> Result<
358 Option<super::expression::InputValue>,
359 super::expression::ExpressionError,
360 > {
361 let input_merge_expr = match self {
362 Function::Remote(RemoteFunction::Scalar { .. }) => None,
363 Function::Remote(RemoteFunction::Vector {
364 input_merge, ..
365 }) => Some(input_merge),
366 Function::Inline(InlineFunction::Scalar { .. }) => None,
367 Function::Inline(InlineFunction::Vector {
368 input_merge, ..
369 }) => input_merge,
370 };
371 match input_merge_expr {
372 Some(input_merge_expr) => {
373 // prepare params for compiling input_merge expression
374 let params = super::expression::Params::Ref(
375 super::expression::ParamsRef {
376 input,
377 output: None,
378 map: None,
379 },
380 );
381 // compile input_merge
382 let input_merge = input_merge_expr.compile_one(¶ms)?;
383 Ok(Some(input_merge))
384 }
385 None => Ok(None),
386 }
387 }
388
389 /// Returns the function's description, if available.
390 pub fn description(&self) -> Option<&str> {
391 match self {
392 Function::Remote(remote_function) => {
393 Some(remote_function.description())
394 }
395 Function::Inline(_) => None,
396 }
397 }
398
399 /// Returns the function's input schema, if available.
400 pub fn input_schema(&self) -> Option<&super::expression::InputSchema> {
401 match self {
402 Function::Remote(remote_function) => {
403 Some(remote_function.input_schema())
404 }
405 Function::Inline(_) => None,
406 }
407 }
408
409 /// Returns the function's tasks.
410 pub fn tasks(&self) -> &[super::TaskExpression] {
411 match self {
412 Function::Remote(remote_function) => remote_function.tasks(),
413 Function::Inline(inline_function) => inline_function.tasks(),
414 }
415 }
416
417 /// Returns the function's expected output length expression, if defined.
418 pub fn output_length(&self) -> Option<&super::expression::Expression> {
419 match self {
420 Function::Remote(remote_function) => {
421 remote_function.output_length()
422 }
423 Function::Inline(_) => None,
424 }
425 }
426
427 /// Returns the function's input_split expression, if defined.
428 pub fn input_split(&self) -> Option<&super::expression::Expression> {
429 match self {
430 Function::Remote(remote_function) => remote_function.input_split(),
431 Function::Inline(inline_function) => inline_function.input_split(),
432 }
433 }
434
435 /// Returns the function's input_merge expression, if defined.
436 pub fn input_merge(&self) -> Option<&super::expression::Expression> {
437 match self {
438 Function::Remote(remote_function) => remote_function.input_merge(),
439 Function::Inline(inline_function) => inline_function.input_merge(),
440 }
441 }
442}
443
444/// A remote function with full metadata.
445///
446/// Remote functions are stored as `function.json` in repositories and
447/// referenced by `remote/owner/repository`. They include documentation fields
448/// that inline functions lack.
449#[derive(
450 Debug,
451 Clone,
452 PartialEq,
453 Serialize,
454 Deserialize,
455 JsonSchema,
456 arbitrary::Arbitrary,
457)]
458#[serde(tag = "type")]
459#[schemars(rename = "functions.RemoteFunction")]
460pub enum RemoteFunction {
461 /// Produces a single score in [0, 1].
462 #[schemars(title = "Scalar")]
463 #[serde(rename = "scalar.function")]
464 Scalar {
465 /// Human-readable description of what the function does.
466 description: String,
467 /// JSON Schema defining the expected input structure.
468 input_schema: super::expression::InputSchema,
469 /// The list of tasks to execute. Tasks with a `map` expression are
470 /// expanded into multiple instances. Each instance is compiled with
471 /// `map` set to the current integer index.
472 /// Receives: `input`, `map` (if mapped).
473 tasks: Vec<super::TaskExpression>,
474 },
475 /// Produces a vector of scores that sums to 1.
476 #[schemars(title = "Vector")]
477 #[serde(rename = "vector.function")]
478 Vector {
479 /// Human-readable description of what the function does.
480 description: String,
481 /// JSON Schema defining the expected input structure.
482 input_schema: super::expression::InputSchema,
483 /// The list of tasks to execute. Tasks with a `map` expression are
484 /// expanded into multiple instances. Each instance is compiled with
485 /// `map` set to the current integer index.
486 /// Receives: `input`, `map` (if mapped).
487 tasks: Vec<super::TaskExpression>,
488 /// Expression computing the expected output vector length for task outputs.
489 /// Receives: `input`.
490 output_length: super::expression::Expression,
491 /// Expression transforming input into an input array of the output_length
492 /// When the Function is executed with any input from the array,
493 /// The output_length should be 1.
494 /// Receives: `input`.
495 input_split: super::expression::Expression,
496 /// Expression transforming an array of inputs computed by `input_split`
497 /// into a single Input object for the Function.
498 /// Receives: `input` (as an array).
499 input_merge: super::expression::Expression,
500 },
501}
502
503impl RemoteFunction {
504 /// Returns the function's description.
505 pub fn description(&self) -> &str {
506 match self {
507 RemoteFunction::Scalar { description, .. } => description,
508 RemoteFunction::Vector { description, .. } => description,
509 }
510 }
511
512 /// Returns the function's input schema.
513 pub fn input_schema(&self) -> &super::expression::InputSchema {
514 match self {
515 RemoteFunction::Scalar { input_schema, .. } => input_schema,
516 RemoteFunction::Vector { input_schema, .. } => input_schema,
517 }
518 }
519
520 /// Returns the function's tasks.
521 pub fn tasks(&self) -> &[super::TaskExpression] {
522 match self {
523 RemoteFunction::Scalar { tasks, .. } => tasks,
524 RemoteFunction::Vector { tasks, .. } => tasks,
525 }
526 }
527
528 /// Returns the function's expected output length, if defined (vector functions only).
529 pub fn output_length(&self) -> Option<&super::expression::Expression> {
530 match self {
531 RemoteFunction::Scalar { .. } => None,
532 RemoteFunction::Vector { output_length, .. } => Some(output_length),
533 }
534 }
535
536 /// Returns the function's input_split expression, if defined (vector functions only).
537 pub fn input_split(&self) -> Option<&super::expression::Expression> {
538 match self {
539 RemoteFunction::Scalar { .. } => None,
540 RemoteFunction::Vector { input_split, .. } => Some(input_split),
541 }
542 }
543
544 /// Returns the function's input_merge expression, if defined (vector functions only).
545 pub fn input_merge(&self) -> Option<&super::expression::Expression> {
546 match self {
547 RemoteFunction::Scalar { .. } => None,
548 RemoteFunction::Vector { input_merge, .. } => Some(input_merge),
549 }
550 }
551
552 pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
553 self.tasks().iter().filter_map(|task| match task {
554 super::TaskExpression::ScalarFunction(t) => Some(&t.path),
555 super::TaskExpression::VectorFunction(t) => Some(&t.path),
556 _ => None,
557 })
558 }
559}
560
561/// An inline function definition without metadata.
562///
563/// Used when embedding function logic directly in requests rather than
564/// referencing a remote function. Lacks description and input
565/// schema fields.
566#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
567#[serde(tag = "type")]
568#[schemars(rename = "functions.InlineFunction")]
569pub enum InlineFunction {
570 /// Produces a single score in [0, 1].
571 #[schemars(title = "Scalar")]
572 #[serde(rename = "scalar.function")]
573 Scalar {
574 /// The list of tasks to execute. Tasks with a `map` expression are
575 /// expanded into multiple instances. Each instance is compiled with
576 /// `map` set to the current integer index.
577 /// Receives: `input`, `map` (if mapped).
578 tasks: Vec<super::TaskExpression>,
579 },
580 /// Produces a vector of scores that sums to 1.
581 #[schemars(title = "Vector")]
582 #[serde(rename = "vector.function")]
583 Vector {
584 /// The list of tasks to execute. Tasks with a `map` expression are
585 /// expanded into multiple instances. Each instance is compiled with
586 /// `map` set to the current integer index.
587 /// Receives: `input`, `map` (if mapped).
588 tasks: Vec<super::TaskExpression>,
589 /// Expression transforming input into an input array of the output_length
590 /// When the Function is executed with any input from the array,
591 /// The output_length should be 1.
592 /// Receives: `input`.
593 /// Only required if the request uses a strategy that needs input splitting.
594 input_split: Option<super::expression::Expression>,
595 /// Expression transforming an array of inputs computed by `input_split`
596 /// into a single Input object for the Function.
597 /// Receives: `input` (as an array).
598 /// Only required if the request uses a strategy that needs input splitting.
599 input_merge: Option<super::expression::Expression>,
600 },
601}
602
603impl InlineFunction {
604 /// Returns the function's tasks.
605 pub fn tasks(&self) -> &[super::TaskExpression] {
606 match self {
607 InlineFunction::Scalar { tasks, .. } => tasks,
608 InlineFunction::Vector { tasks, .. } => tasks,
609 }
610 }
611
612 /// Returns the function's input_split expression, if defined (vector functions only).
613 pub fn input_split(&self) -> Option<&super::expression::Expression> {
614 match self {
615 InlineFunction::Scalar { .. } => None,
616 InlineFunction::Vector { input_split, .. } => input_split.as_ref(),
617 }
618 }
619
620 /// Returns the function's input_merge expression, if defined (vector functions only).
621 pub fn input_merge(&self) -> Option<&super::expression::Expression> {
622 match self {
623 InlineFunction::Scalar { .. } => None,
624 InlineFunction::Vector { input_merge, .. } => input_merge.as_ref(),
625 }
626 }
627
628 pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
629 self.tasks().iter().filter_map(|task| match task {
630 super::TaskExpression::ScalarFunction(t) => Some(&t.path),
631 super::TaskExpression::VectorFunction(t) => Some(&t.path),
632 _ => None,
633 })
634 }
635}
636
637#[derive(
638 Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, JsonSchema,
639)]
640#[schemars(rename = "functions.FunctionType")]
641pub enum FunctionType {
642 #[schemars(title = "Scalar")]
643 #[serde(rename = "scalar.function")]
644 Scalar,
645 #[schemars(title = "Vector")]
646 #[serde(rename = "vector.function")]
647 Vector,
648}