// objectiveai_sdk/functions/alpha_scalar/task.rs

1use crate::{agent, functions};
2use serde::{Deserialize, Serialize};
3use schemars::JsonSchema;
4
// A branch task of an alpha scalar function: a task that references another
// scalar function, either a concrete remote one or a placeholder.
//
// Serialized as an internally tagged enum; the `type` field selects the
// variant ("alpha.scalar.function" or "placeholder.alpha.scalar.function").
//
// Plain `//` comments are used on this `JsonSchema`-deriving type on purpose:
// schemars turns `///` doc comments into `description` entries, which would
// change the generated JSON Schema.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(tag = "type")]
#[schemars(rename = "functions.alpha_scalar.BranchTaskExpression")]
pub enum BranchTaskExpression {
    // References a scalar function located at a remote path.
    #[schemars(title = "ScalarFunction")]
    #[serde(rename = "alpha.scalar.function")]
    ScalarFunction(ScalarFunctionTaskExpression),
    // Stands in for a scalar function described by invention params and an
    // input schema; it carries no remote path.
    #[schemars(title = "PlaceholderScalarFunction")]
    #[serde(rename = "placeholder.alpha.scalar.function")]
    PlaceholderScalarFunction(PlaceholderScalarFunctionTaskExpression),
}
16
17impl BranchTaskExpression {
18    pub fn url(&self) -> Option<String> {
19        match self {
20            BranchTaskExpression::ScalarFunction(task) => Some(task.url()),
21            BranchTaskExpression::PlaceholderScalarFunction(_) => None,
22        }
23    }
24
25    pub fn transpile(self) -> functions::TaskExpression {
26        match self {
27            BranchTaskExpression::ScalarFunction(task) => {
28                functions::TaskExpression::ScalarFunction(task.transpile())
29            }
30            BranchTaskExpression::PlaceholderScalarFunction(task) => {
31                functions::TaskExpression::PlaceholderScalarFunction(
32                    task.transpile(),
33                )
34            }
35        }
36    }
37
38    pub fn is_placeholder(&self) -> bool {
39        match self {
40            BranchTaskExpression::ScalarFunction(_) => false,
41            BranchTaskExpression::PlaceholderScalarFunction(_) => true,
42        }
43    }
44}
45
// Partial form of a placeholder branch task, before the invention parameters
// (name, depth, branch/leaf widths) have been supplied. `complete` turns it
// into a `BranchTaskExpression::PlaceholderScalarFunction`.
//
// Uses the same "placeholder.alpha.scalar.function" `type` tag as
// `BranchTaskExpression`. Plain `//` comments keep the schemars-generated
// schema unchanged.
// NOTE(review): unlike `BranchTaskExpression`, this variant has no
// `#[schemars(title = ...)]`; confirm whether that is intentional.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(tag = "type")]
#[schemars(rename = "functions.alpha_scalar.PartialPlaceholderBranchTaskExpression")]
pub enum PartialPlaceholderBranchTaskExpression {
    #[serde(rename = "placeholder.alpha.scalar.function")]
    PlaceholderScalarFunction(PartialPlaceholderScalarFunctionTaskExpression),
}
53
54impl PartialPlaceholderBranchTaskExpression {
55    pub fn complete(
56        self,
57        name: String,
58        depth: u64,
59        min_branch_width: u64,
60        max_branch_width: u64,
61        min_leaf_width: u64,
62        max_leaf_width: u64,
63    ) -> BranchTaskExpression {
64        match self {
65            PartialPlaceholderBranchTaskExpression::PlaceholderScalarFunction(
66                task,
67            ) => BranchTaskExpression::PlaceholderScalarFunction(
68                task.complete(
69                    name,
70                    depth,
71                    min_branch_width,
72                    max_branch_width,
73                    min_leaf_width,
74                    max_leaf_width,
75                ),
76            ),
77        }
78    }
79}
80
// A leaf task of an alpha scalar function. Its only variant is a vector
// completion. Serialized as an internally tagged enum on the `type` field
// ("vector.completion"). Plain `//` comments keep the schemars-generated
// schema unchanged.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(tag = "type")]
#[schemars(rename = "functions.alpha_scalar.LeafTaskExpression")]
pub enum LeafTaskExpression {
    #[serde(rename = "vector.completion")]
    VectorCompletion(VectorCompletionTaskExpression),
}
88
89impl LeafTaskExpression {
90    pub fn transpile(self) -> functions::TaskExpression {
91        match self {
92            LeafTaskExpression::VectorCompletion(task) => {
93                functions::TaskExpression::VectorCompletion(task.transpile())
94            }
95        }
96    }
97}
98
// Task that references a scalar function located at a remote path.
// Plain `//` comments keep the schemars-generated schema unchanged.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[schemars(rename = "functions.alpha_scalar.ScalarFunctionTaskExpression")]
pub struct ScalarFunctionTaskExpression {
    // Location of the function; its fields are flattened into this object
    // when (de)serialized.
    #[serde(flatten)]
    #[schemars(schema_with = "crate::flatten_schema::<crate::RemotePath>")]
    pub path: crate::RemotePath,
    // Optional `skip` expression, presumably deciding whether the task is
    // skipped (TODO confirm). Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub skip: Option<functions::expression::Expression>,
    // Expression describing the input passed to the referenced function.
    pub input: super::expression::ScalarFunctionInputValueExpression,
}
110
111impl ScalarFunctionTaskExpression {
112    pub fn url(&self) -> String {
113        self.path.url()
114    }
115
116    pub fn transpile(self) -> functions::ScalarFunctionTaskExpression {
117        functions::ScalarFunctionTaskExpression {
118            path: self.path,
119            skip: self.skip,
120            map: None,
121            input:
122                super::expression::scalar_function_input_value_expression::transpile(
123                    self.input,
124                ),
125            output: functions::expression::Expression::Special(
126                functions::expression::Special::Output,
127            ),
128        }
129    }
130}
131
// Task that stands in for a scalar function described by invention
// parameters and an input schema rather than by a remote path.
// Plain `//` comments keep the schemars-generated schema unchanged.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[schemars(rename = "functions.alpha_scalar.PlaceholderScalarFunctionTaskExpression")]
pub struct PlaceholderScalarFunctionTaskExpression {
    // Invention parameters (name, spec, depth, branch/leaf widths); flattened
    // into this object when (de)serialized.
    #[serde(flatten)]
    pub params: functions::inventions::Params,
    // Schema of the input the placeholder function is expected to accept.
    pub input_schema: super::expression::ScalarFunctionInputSchema,
    // Optional `skip` expression, presumably deciding whether the task is
    // skipped (TODO confirm). Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub skip: Option<functions::expression::Expression>,
    // Expression describing the input passed to the function.
    pub input: super::expression::ScalarFunctionInputValueExpression,
}
143
144impl PlaceholderScalarFunctionTaskExpression {
145    pub fn transpile(
146        self,
147    ) -> functions::PlaceholderScalarFunctionTaskExpression {
148        functions::PlaceholderScalarFunctionTaskExpression {
149            input_schema:
150                super::expression::scalar_function_input_schema::transpile(
151                    self.input_schema,
152                ),
153            skip: self.skip,
154            map: None,
155            input:
156                super::expression::scalar_function_input_value_expression::transpile(
157                    self.input,
158                ),
159            output: functions::expression::Expression::Special(
160                functions::expression::Special::Output,
161            ),
162        }
163    }
164
165    pub fn replace(
166        self,
167        path: &crate::RemotePath,
168    ) -> ScalarFunctionTaskExpression {
169        ScalarFunctionTaskExpression {
170            path: path.clone(),
171            skip: self.skip,
172            input: self.input,
173        }
174    }
175}
176
// Partial placeholder scalar-function task: it carries only the `spec` part of
// the invention parameters. `complete` supplies the rest (name, depth,
// branch/leaf widths) to build a `PlaceholderScalarFunctionTaskExpression`.
// Plain `//` comments keep the schemars-generated schema unchanged.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[schemars(rename = "functions.alpha_scalar.PartialPlaceholderScalarFunctionTaskExpression")]
pub struct PartialPlaceholderScalarFunctionTaskExpression {
    // Becomes `Params::spec` once completed.
    pub spec: String,
    // Schema of the input the placeholder function is expected to accept.
    pub input_schema: super::expression::ScalarFunctionInputSchema,
    // Optional `skip` expression, presumably deciding whether the task is
    // skipped (TODO confirm). Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub skip: Option<functions::expression::Expression>,
    // Expression describing the input passed to the function.
    pub input: super::expression::ScalarFunctionInputValueExpression,
}
187
188impl PartialPlaceholderScalarFunctionTaskExpression {
189    pub fn complete(
190        self,
191        name: String,
192        depth: u64,
193        min_branch_width: u64,
194        max_branch_width: u64,
195        min_leaf_width: u64,
196        max_leaf_width: u64,
197    ) -> PlaceholderScalarFunctionTaskExpression {
198        PlaceholderScalarFunctionTaskExpression {
199            params: functions::inventions::Params {
200                depth,
201                min_branch_width,
202                max_branch_width,
203                min_leaf_width,
204                max_leaf_width,
205                name,
206                spec: self.spec,
207            },
208            input_schema: self.input_schema,
209            skip: self.skip,
210            input: self.input,
211        }
212    }
213}
214
// Leaf task that runs a vector completion over a fixed list of responses.
// Plain `//` comments keep the schemars-generated schema unchanged.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[schemars(rename = "functions.alpha_scalar.VectorCompletionTaskExpression")]
pub struct VectorCompletionTaskExpression {
    // Optional `skip` expression, presumably deciding whether the task is
    // skipped (TODO confirm). Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub skip: Option<functions::expression::Expression>,
    // Expression that yields the messages for the completion (presumably the
    // prompt conversation; TODO confirm).
    pub messages: functions::expression::Expression,
    // Candidate responses, given as literal rich-content values rather than
    // expressions.
    pub responses: Vec<agent::completions::message::RichContent>,
}
224
225impl VectorCompletionTaskExpression {
226    pub fn transpile(self) -> functions::VectorCompletionTaskExpression {
227        functions::VectorCompletionTaskExpression {
228            skip: self.skip,
229            map: None,
230            messages: functions::expression::WithExpression::Expression(
231                self.messages,
232            ),
233            responses: functions::expression::WithExpression::Value(
234                self.responses
235                    .into_iter()
236                    .map(agent::completions::message::RichContentExpression::from)
237                    .map(functions::expression::WithExpression::Value)
238                    .collect(),
239            ),
240            output: functions::expression::Expression::Special(
241                functions::expression::Special::TaskOutputWeightedSum,
242            ),
243        }
244    }
245}