// objectiveai_sdk/functions/alpha_vector/function.rs

1use crate::functions;
2use serde::{Deserialize, Serialize};
3use schemars::JsonSchema;
4
5#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
6#[serde(tag = "type")]
7#[schemars(rename = "functions.alpha_vector.RemoteFunction")]
8pub enum RemoteFunction {
9    #[schemars(title = "Branch")]
10    #[serde(rename = "alpha.vector.branch.function")]
11    Branch {
12        description: String,
13        input_schema: super::expression::VectorFunctionInputSchema,
14        tasks: Vec<super::BranchTaskExpression>,
15    },
16    #[schemars(title = "Leaf")]
17    #[serde(rename = "alpha.vector.leaf.function")]
18    Leaf {
19        description: String,
20        input_schema: super::expression::VectorFunctionInputSchema,
21        tasks: Vec<super::LeafTaskExpression>,
22    },
23}
24
25impl RemoteFunction {
26    pub fn tasks(&self) -> &[super::BranchTaskExpression] {
27        match self {
28            RemoteFunction::Branch { tasks, .. } => tasks,
29            RemoteFunction::Leaf { .. } => &[],
30        }
31    }
32
33    pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
34        self.tasks().iter().filter_map(|task| match task {
35            super::BranchTaskExpression::ScalarFunction(t) => Some(&t.path),
36            super::BranchTaskExpression::VectorFunction(t) => Some(&t.path),
37            _ => None,
38        })
39    }
40
41    pub fn transpile(self) -> functions::RemoteFunction {
42        match self {
43            RemoteFunction::Branch {
44                description,
45                input_schema,
46                tasks,
47            } => functions::RemoteFunction::Vector {
48                description,
49                input_schema: input_schema.transpile(),
50                tasks: tasks
51                    .into_iter()
52                    .map(super::BranchTaskExpression::transpile)
53                    .collect(),
54                output_length: functions::expression::Expression::Special(
55                    functions::expression::Special::InputItemsOutputLength,
56                ),
57                input_split: functions::expression::Expression::Special(
58                    functions::expression::Special::InputItemsOptionalContextSplit,
59                ),
60                input_merge: functions::expression::Expression::Special(
61                    functions::expression::Special::InputItemsOptionalContextMerge,
62                ),
63            },
64            RemoteFunction::Leaf {
65                description,
66                input_schema,
67                tasks,
68            } => functions::RemoteFunction::Vector {
69                description,
70                input_schema: input_schema.transpile(),
71                tasks: tasks
72                    .into_iter()
73                    .map(super::LeafTaskExpression::transpile)
74                    .collect(),
75                output_length: functions::expression::Expression::Special(
76                    functions::expression::Special::InputItemsOutputLength,
77                ),
78                input_split: functions::expression::Expression::Special(
79                    functions::expression::Special::InputItemsOptionalContextSplit,
80                ),
81                input_merge: functions::expression::Expression::Special(
82                    functions::expression::Special::InputItemsOptionalContextMerge,
83                ),
84            },
85        }
86    }
87}
88
89#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
90#[serde(tag = "type")]
91#[schemars(rename = "functions.alpha_vector.InlineFunction")]
92pub enum InlineFunction {
93    #[schemars(title = "Branch")]
94    #[serde(rename = "alpha.vector.branch.function")]
95    Branch {
96        tasks: Vec<super::BranchTaskExpression>,
97    },
98    #[schemars(title = "Leaf")]
99    #[serde(rename = "alpha.vector.leaf.function")]
100    Leaf {
101        tasks: Vec<super::LeafTaskExpression>,
102    },
103}
104
105impl InlineFunction {
106    pub fn tasks(&self) -> &[super::BranchTaskExpression] {
107        match self {
108            InlineFunction::Branch { tasks, .. } => tasks,
109            InlineFunction::Leaf { .. } => &[],
110        }
111    }
112
113    pub fn remotes(&self) -> impl Iterator<Item = &crate::RemotePath> {
114        self.tasks().iter().filter_map(|task| match task {
115            super::BranchTaskExpression::ScalarFunction(t) => Some(&t.path),
116            super::BranchTaskExpression::VectorFunction(t) => Some(&t.path),
117            _ => None,
118        })
119    }
120
121    pub fn transpile(self) -> functions::InlineFunction {
122        match self {
123            InlineFunction::Branch { tasks } => {
124                functions::InlineFunction::Vector {
125                    tasks: tasks
126                        .into_iter()
127                        .map(super::BranchTaskExpression::transpile)
128                        .collect(),
129                    input_split: Some(functions::expression::Expression::Special(
130                        functions::expression::Special::InputItemsOptionalContextSplit,
131                    )),
132                    input_merge: Some(functions::expression::Expression::Special(
133                        functions::expression::Special::InputItemsOptionalContextMerge,
134                    )),
135                }
136            }
137            InlineFunction::Leaf { tasks } => {
138                functions::InlineFunction::Vector {
139                    tasks: tasks
140                        .into_iter()
141                        .map(super::LeafTaskExpression::transpile)
142                        .collect(),
143                    input_split: Some(functions::expression::Expression::Special(
144                        functions::expression::Special::InputItemsOptionalContextSplit,
145                    )),
146                    input_merge: Some(functions::expression::Expression::Special(
147                        functions::expression::Special::InputItemsOptionalContextMerge,
148                    )),
149                }
150            }
151        }
152    }
153}