Skip to main content

objectiveai_sdk/functions/alpha_vector/function.rs

1use crate::functions;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4
5#[derive(
6    Debug,
7    Clone,
8    PartialEq,
9    Serialize,
10    Deserialize,
11    JsonSchema,
12    arbitrary::Arbitrary,
13)]
14#[serde(tag = "type")]
15#[schemars(rename = "functions.alpha_vector.RemoteFunction")]
16pub enum RemoteFunction {
17    #[schemars(title = "Branch")]
18    #[serde(rename = "alpha.vector.branch.function")]
19    Branch {
20        description: String,
21        input_schema: super::expression::VectorFunctionInputSchema,
22        tasks: Vec<super::BranchTaskExpression>,
23    },
24    #[schemars(title = "Leaf")]
25    #[serde(rename = "alpha.vector.leaf.function")]
26    Leaf {
27        description: String,
28        input_schema: super::expression::VectorFunctionInputSchema,
29        tasks: Vec<super::LeafTaskExpression>,
30    },
31}
32
33impl RemoteFunction {
34    pub fn tasks(&self) -> &[super::BranchTaskExpression] {
35        match self {
36            RemoteFunction::Branch { tasks, .. } => tasks,
37            RemoteFunction::Leaf { .. } => &[],
38        }
39    }
40
41    pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
42        self.tasks().iter().filter_map(|task| match task {
43            super::BranchTaskExpression::ScalarFunction(t) => Some(&t.path),
44            super::BranchTaskExpression::VectorFunction(t) => Some(&t.path),
45            _ => None,
46        })
47    }
48
49    pub fn transpile(self) -> functions::RemoteFunction {
50        match self {
51            RemoteFunction::Branch {
52                description,
53                input_schema,
54                tasks,
55            } => functions::RemoteFunction::Vector {
56                description,
57                input_schema: input_schema.transpile(),
58                tasks: tasks
59                    .into_iter()
60                    .map(super::BranchTaskExpression::transpile)
61                    .collect(),
62                output_length: functions::expression::Expression::Special(
63                    functions::expression::Special::InputItemsOutputLength,
64                ),
65                input_split: functions::expression::Expression::Special(
66                    functions::expression::Special::InputItemsOptionalContextSplit,
67                ),
68                input_merge: functions::expression::Expression::Special(
69                    functions::expression::Special::InputItemsOptionalContextMerge,
70                ),
71            },
72            RemoteFunction::Leaf {
73                description,
74                input_schema,
75                tasks,
76            } => functions::RemoteFunction::Vector {
77                description,
78                input_schema: input_schema.transpile(),
79                tasks: tasks
80                    .into_iter()
81                    .map(super::LeafTaskExpression::transpile)
82                    .collect(),
83                output_length: functions::expression::Expression::Special(
84                    functions::expression::Special::InputItemsOutputLength,
85                ),
86                input_split: functions::expression::Expression::Special(
87                    functions::expression::Special::InputItemsOptionalContextSplit,
88                ),
89                input_merge: functions::expression::Expression::Special(
90                    functions::expression::Special::InputItemsOptionalContextMerge,
91                ),
92            },
93        }
94    }
95}
96
97#[derive(
98    Debug,
99    Clone,
100    PartialEq,
101    Serialize,
102    Deserialize,
103    JsonSchema,
104    arbitrary::Arbitrary,
105)]
106#[serde(tag = "type")]
107#[schemars(rename = "functions.alpha_vector.InlineFunction")]
108pub enum InlineFunction {
109    #[schemars(title = "Branch")]
110    #[serde(rename = "alpha.vector.branch.function")]
111    Branch {
112        tasks: Vec<super::BranchTaskExpression>,
113    },
114    #[schemars(title = "Leaf")]
115    #[serde(rename = "alpha.vector.leaf.function")]
116    Leaf {
117        tasks: Vec<super::LeafTaskExpression>,
118    },
119}
120
121impl InlineFunction {
122    pub fn tasks(&self) -> &[super::BranchTaskExpression] {
123        match self {
124            InlineFunction::Branch { tasks, .. } => tasks,
125            InlineFunction::Leaf { .. } => &[],
126        }
127    }
128
129    pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
130        self.tasks().iter().filter_map(|task| match task {
131            super::BranchTaskExpression::ScalarFunction(t) => Some(&t.path),
132            super::BranchTaskExpression::VectorFunction(t) => Some(&t.path),
133            _ => None,
134        })
135    }
136
137    pub fn transpile(self) -> functions::InlineFunction {
138        match self {
139            InlineFunction::Branch { tasks } => {
140                functions::InlineFunction::Vector {
141                    tasks: tasks
142                        .into_iter()
143                        .map(super::BranchTaskExpression::transpile)
144                        .collect(),
145                    input_split: Some(functions::expression::Expression::Special(
146                        functions::expression::Special::InputItemsOptionalContextSplit,
147                    )),
148                    input_merge: Some(functions::expression::Expression::Special(
149                        functions::expression::Special::InputItemsOptionalContextMerge,
150                    )),
151                }
152            }
153            InlineFunction::Leaf { tasks } => {
154                functions::InlineFunction::Vector {
155                    tasks: tasks
156                        .into_iter()
157                        .map(super::LeafTaskExpression::transpile)
158                        .collect(),
159                    input_split: Some(functions::expression::Expression::Special(
160                        functions::expression::Special::InputItemsOptionalContextSplit,
161                    )),
162                    input_merge: Some(functions::expression::Expression::Special(
163                        functions::expression::Special::InputItemsOptionalContextMerge,
164                    )),
165                }
166            }
167        }
168    }
169}