Skip to main content

objectiveai_api/functions/
flat_task_profile.rs

1//! Flattened execution data for function execution.
2//!
3//! Transforms nested Function and Profile definitions into a flat structure suitable
4//! for parallel execution. A Function defines the task structure and expressions,
5//! while a Profile provides the weights for each task. This module combines both
6//! into flattened executable tasks.
7
8use crate::ctx;
9use futures::FutureExt;
10use std::{
11    pin::Pin,
12    sync::{Arc, LazyLock},
13    task::Poll,
14};
15
/// A flattened task ready for execution.
///
/// Combines Function structure with Profile weights into an executable node.
/// Can be a function (with nested tasks), a mapped array of functions, a vector
/// completion, or a mapped array of vector completions.
///
/// The two function variants are interior nodes whose children are again
/// `FlatTaskProfile`s; the two vector completion variants are the leaves that
/// `vector_completion_ftps` yields.
#[derive(Debug, Clone)]
pub enum FlatTaskProfile {
    /// A single function task with nested tasks.
    Function(FunctionFlatTaskProfile),
    /// Multiple function tasks from a mapped expression.
    MapFunction(MapFunctionFlatTaskProfile),
    /// A single vector completion task.
    VectorCompletion(VectorCompletionFlatTaskProfile),
    /// Multiple vector completion tasks from a mapped expression.
    ///
    /// Also used, with no elements, for any mapped task that compiled to an
    /// empty array.
    MapVectorCompletion(MapVectorCompletionFlatTaskProfile),
}
32
33impl FlatTaskProfile {
34    /// Returns an iterator over all vector completion tasks.
35    ///
36    /// Recursively traverses function tasks to collect all leaf vector completions.
37    pub fn vector_completion_ftps(
38        &self,
39    ) -> impl Iterator<Item = &VectorCompletionFlatTaskProfile> {
40        enum Iter<'a> {
41            Function(
42                Box<
43                    dyn Iterator<Item = &'a VectorCompletionFlatTaskProfile>
44                        + 'a,
45                >,
46            ),
47            MapFunction(
48                Box<
49                    dyn Iterator<Item = &'a VectorCompletionFlatTaskProfile>
50                        + 'a,
51                >,
52            ),
53            VectorCompletion(Option<&'a VectorCompletionFlatTaskProfile>),
54            MapVectorCompletion(
55                std::slice::Iter<'a, VectorCompletionFlatTaskProfile>,
56            ),
57        }
58        impl<'a> Iterator for Iter<'a> {
59            type Item = &'a VectorCompletionFlatTaskProfile;
60            fn next(&mut self) -> Option<Self::Item> {
61                match self {
62                    Iter::Function(iter) => iter.next(),
63                    Iter::MapFunction(iter) => iter.next(),
64                    Iter::VectorCompletion(opt) => opt.take(),
65                    Iter::MapVectorCompletion(iter) => iter.next(),
66                }
67            }
68        }
69        match self {
70            FlatTaskProfile::Function(function) => Iter::Function(Box::new(
71                function
72                    .tasks
73                    .iter()
74                    .filter_map(|task| task.as_ref())
75                    .flat_map(|task| task.vector_completion_ftps()),
76            )),
77            FlatTaskProfile::MapFunction(functions) => {
78                Iter::MapFunction(Box::new(
79                    functions
80                        .functions
81                        .iter()
82                        .flat_map(|function| function.tasks.iter())
83                        .filter_map(|task| task.as_ref())
84                        .flat_map(|task| task.vector_completion_ftps()),
85                ))
86            }
87            FlatTaskProfile::VectorCompletion(vector) => {
88                Iter::VectorCompletion(Some(vector))
89            }
90            FlatTaskProfile::MapVectorCompletion(vectors) => {
91                Iter::MapVectorCompletion(vectors.vector_completions.iter())
92            }
93        }
94    }
95    /// Returns the total number of leaf tasks (vector completions).
96    pub fn len(&self) -> usize {
97        match self {
98            FlatTaskProfile::Function(function) => function.len(),
99            FlatTaskProfile::MapFunction(functions) => functions.len(),
100            FlatTaskProfile::VectorCompletion(vector) => vector.len(),
101            FlatTaskProfile::MapVectorCompletion(vectors) => vectors.len(),
102        }
103    }
104
105    /// Returns the number of task indices needed for output assembly.
106    pub fn task_index_len(&self) -> usize {
107        match self {
108            FlatTaskProfile::Function(function) => function.task_index_len(),
109            FlatTaskProfile::MapFunction(functions) => {
110                functions.task_index_len()
111            }
112            FlatTaskProfile::VectorCompletion(vector) => {
113                vector.task_index_len()
114            }
115            FlatTaskProfile::MapVectorCompletion(vectors) => {
116                vectors.task_index_len()
117            }
118        }
119    }
120}
121
/// Multiple function tasks from a mapped expression.
///
/// Created when a task has a `map` index pointing to an input_maps sub-array.
/// Each element in that array produces one function instance.
#[derive(Debug, Clone)]
pub struct MapFunctionFlatTaskProfile {
    /// Path to this task in the Function tree (indices into tasks arrays).
    ///
    /// Each entry of `functions` is built with this path plus its own element
    /// index appended.
    pub path: Vec<u64>,
    /// The individual flattened function tasks, one per element in the mapped array.
    pub functions: Vec<FunctionFlatTaskProfile>,
}
133
134impl MapFunctionFlatTaskProfile {
135    pub fn len(&self) -> usize {
136        self.functions
137            .iter()
138            .map(FunctionFlatTaskProfile::len)
139            .sum()
140    }
141
142    pub fn task_index_len(&self) -> usize {
143        self.functions
144            .iter()
145            .map(FunctionFlatTaskProfile::task_index_len)
146            .sum::<usize>()
147            .max(1)
148    }
149}
150
/// A flattened function task ready for execution.
///
/// Combines a Function definition with its corresponding Profile. Contains the
/// compiled input, nested tasks with their weights, and the output expression.
#[derive(Debug, Clone)]
pub struct FunctionFlatTaskProfile {
    /// Path to this task in the Function tree (indices into tasks arrays).
    pub path: Vec<u64>,
    /// Full Function ID (owner, repository, commit) if from GitHub.
    pub full_function_id: Option<(String, String, String)>,
    /// Full Profile ID (owner, repository, commit) if from GitHub.
    pub full_profile_id: Option<(String, String, String)>,
    /// Description from the Function definition.
    pub description: Option<String>,
    /// The compiled input for this Function.
    pub input: objectiveai::functions::expression::Input,
    /// The flattened child tasks (None if task was skipped).
    ///
    /// One slot per task of the Function, in task order; skipped tasks keep
    /// their slot so positions stay aligned with the Function's task list.
    pub tasks: Vec<Option<FlatTaskProfile>>,
    /// The output expression for computing the final score.
    pub output: objectiveai::functions::expression::Expression,
    /// The Function type (scalar or vector).
    pub r#type: FunctionType,
}
174
175impl FunctionFlatTaskProfile {
176    pub fn len(&self) -> usize {
177        self.tasks
178            .iter()
179            .map(|task| task.as_ref().map_or(1, |task| task.len()))
180            .sum()
181    }
182
183    pub fn task_index_len(&self) -> usize {
184        let mut len = 0;
185        for task in &self.tasks {
186            len += if let Some(task) = task {
187                task.task_index_len()
188            } else {
189                1
190            };
191        }
192        len
193    }
194
195    pub fn task_indices(&self) -> Vec<u64> {
196        let mut indices = Vec::with_capacity(self.tasks.len());
197        let mut current_index = 0u64;
198        for task in &self.tasks {
199            indices.push(current_index);
200            current_index += if let Some(task) = task {
201                task.task_index_len()
202            } else {
203                1
204            } as u64;
205        }
206        indices
207    }
208}
209
/// The type of a Function's output.
#[derive(Debug, Clone)]
pub enum FunctionType {
    /// Produces a single score in [0, 1].
    Scalar,
    /// Produces a vector of scores that sums to ~1.
    Vector {
        /// Expected output length, if known from output_length expression.
        ///
        /// `get_flat_task_profile` fills this in for remote vector functions
        /// (by compiling `output_length` against the input) and leaves it
        /// `None` for inline vector functions.
        output_length: Option<u64>,
        /// input_split expression if defined
        ///
        /// Always `Some` for remote vector functions; copied as-is from
        /// inline ones.
        input_split: Option<
            objectiveai::functions::expression::WithExpression<
                Vec<objectiveai::functions::expression::Input>,
            >,
        >,
        /// input_merge expression if defined
        ///
        /// Always `Some` for remote vector functions; copied as-is from
        /// inline ones.
        input_merge: Option<
            objectiveai::functions::expression::WithExpression<
                objectiveai::functions::expression::Input,
            >,
        >,
    },
}
233
/// Multiple vector completion tasks from a mapped expression.
///
/// Created when a vector completion task has a `map` index. Each element in the
/// mapped array produces one vector completion instance.
///
/// `get_flat_task_profile` also produces this type, with no elements, for any
/// mapped task whose compiled task list is empty.
#[derive(Debug, Clone)]
pub struct MapVectorCompletionFlatTaskProfile {
    /// Path to this task in the Function tree (indices into tasks arrays).
    ///
    /// Each entry of `vector_completions` is built with this path plus its
    /// own element index appended.
    pub path: Vec<u64>,
    /// The individual flattened vector completion tasks.
    pub vector_completions: Vec<VectorCompletionFlatTaskProfile>,
}
245
246impl MapVectorCompletionFlatTaskProfile {
247    pub fn len(&self) -> usize {
248        self.vector_completions.len()
249    }
250
251    pub fn task_index_len(&self) -> usize {
252        self.vector_completions.len().max(1)
253    }
254}
255
/// A flattened vector completion task ready for execution.
///
/// The leaf task type. Contains everything needed to run a vector completion:
/// the Ensemble of LLMs, their weights from the Profile, and the compiled
/// messages/responses.
///
/// As a leaf it always counts as one task and occupies one task index.
#[derive(Debug, Clone)]
pub struct VectorCompletionFlatTaskProfile {
    /// Path to this task in the Function tree (indices into tasks arrays).
    ///
    /// For an instance of a mapped task the last entry is the element index
    /// within the mapped array.
    pub path: Vec<u64>,
    /// The Ensemble configuration with LLMs and their settings.
    pub ensemble: objectiveai::ensemble::EnsembleBase,
    /// The weights for each LLM in the Ensemble (from the Profile).
    pub profile: Vec<rust_decimal::Decimal>,
    /// The compiled messages for the vector completion.
    pub messages: Vec<objectiveai::chat::completions::request::Message>,
    /// Optional tools for the vector completion (read-only context).
    pub tools: Option<Vec<objectiveai::chat::completions::request::Tool>>,
    /// The compiled response options the LLMs will vote on.
    pub responses: Vec<objectiveai::chat::completions::request::RichContent>,
}
276
277impl VectorCompletionFlatTaskProfile {
278    pub fn len(&self) -> usize {
279        1
280    }
281
282    pub fn task_index_len(&self) -> usize {
283        1
284    }
285}
286
/// Parameter for specifying a function source.
#[derive(Debug, Clone)]
pub enum FunctionParam {
    /// Function to fetch from GitHub by owner/repository/commit.
    Remote {
        /// Repository owner, passed to the function fetcher.
        owner: String,
        /// Repository name, passed to the function fetcher.
        repository: String,
        /// Commit to fetch, passed to the function fetcher unchanged.
        ///
        /// NOTE(review): `None` presumably lets the fetcher choose the commit
        /// (the fetched value reports the commit it resolved) — confirm.
        commit: Option<String>,
    },
    /// Already-fetched or inline function definition.
    FetchedOrInline {
        /// (owner, repository, commit) of the function when known; carried
        /// through to `FunctionFlatTaskProfile::full_function_id`.
        full_id: Option<(String, String, String)>,
        /// The function definition to flatten; used as-is, no fetch is made.
        function: objectiveai::functions::Function,
    },
}
302
/// Parameter for specifying a profile source.
#[derive(Debug, Clone)]
pub enum ProfileParam {
    /// Profile to fetch from GitHub by owner/repository/commit.
    Remote {
        /// Repository owner, passed to the profile fetcher.
        owner: String,
        /// Repository name, passed to the profile fetcher.
        repository: String,
        /// Commit to fetch, passed to the profile fetcher unchanged.
        ///
        /// NOTE(review): `None` presumably lets the fetcher choose the commit
        /// (the fetched value reports the commit it resolved) — confirm.
        commit: Option<String>,
    },
    /// Already-fetched or inline profile definition.
    FetchedOrInline {
        /// (owner, repository, commit) of the profile when known; carried
        /// through to `FunctionFlatTaskProfile::full_profile_id`.
        full_id: Option<(String, String, String)>,
        /// The profile definition to apply; used as-is, no fetch is made.
        profile: objectiveai::functions::Profile,
    },
}
318
319/// Recursively builds a flattened task from a Function and Profile.
320///
321/// Fetches any remote Functions/Profiles/Ensembles, compiles task expressions
322/// with the input, and validates that the Profile structure matches the Function.
323/// The result is a flat tree of tasks ready for parallel execution.
324pub async fn get_flat_task_profile<CTXEXT>(
325    ctx: ctx::Context<CTXEXT>,
326    mut path: Vec<u64>,
327    function: FunctionParam,
328    profile: ProfileParam,
329    input: objectiveai::functions::expression::Input,
330    function_fetcher: Arc<
331        impl super::function_fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
332    >,
333    profile_fetcher: Arc<
334        impl super::profile_fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
335    >,
336    ensemble_fetcher: Arc<
337        crate::ensemble::fetcher::CachingFetcher<
338            CTXEXT,
339            impl crate::ensemble::fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
340        >,
341    >,
342) -> Result<super::FunctionFlatTaskProfile, super::executions::Error>
343where
344    CTXEXT: Send + Sync + 'static,
345{
346    static EMPTY_TASKS: LazyLock<
347        Vec<Option<objectiveai::functions::expression::TaskOutput>>,
348    > = LazyLock::new(|| Vec::new());
349
350    // fetch function and profile if needed
351    let (function_full_id, function, profile_full_id, profile): (
352        Option<(String, String, String)>,
353        objectiveai::functions::Function,
354        Option<(String, String, String)>,
355        objectiveai::functions::Profile,
356    ) = match (function, profile) {
357        (
358            FunctionParam::Remote {
359                owner: fowner,
360                repository: frepository,
361                commit: fcommit,
362            },
363            ProfileParam::Remote {
364                owner: powner,
365                repository: prepository,
366                commit: pcommit,
367            },
368        ) => {
369            let ((function, fcommit), (profile, pcommit)) = tokio::try_join!(
370                function_fetcher
371                    .fetch(
372                        ctx.clone(),
373                        &fowner,
374                        &frepository,
375                        fcommit.as_deref()
376                    )
377                    .map(|result| match result {
378                        Ok(Some(function)) => {
379                            Ok((function.inner, function.commit))
380                        }
381                        Ok(_) =>
382                            Err(super::executions::Error::FunctionNotFound),
383                        Err(e) =>
384                            Err(super::executions::Error::FetchFunction(e)),
385                    }),
386                profile_fetcher
387                    .fetch(
388                        ctx.clone(),
389                        &powner,
390                        &prepository,
391                        pcommit.as_deref(),
392                    )
393                    .map(|result| match result {
394                        Ok(Some(profile)) => {
395                            Ok((profile.inner, profile.commit))
396                        }
397                        Ok(_) => Err(super::executions::Error::ProfileNotFound),
398                        Err(e) =>
399                            Err(super::executions::Error::FetchProfile(e)),
400                    }),
401            )?;
402            (
403                Some((fowner.to_owned(), frepository.to_owned(), fcommit)),
404                objectiveai::functions::Function::Remote(function),
405                Some((powner, prepository, pcommit)),
406                objectiveai::functions::Profile::Remote(profile),
407            )
408        }
409        (
410            FunctionParam::Remote {
411                owner: fowner,
412                repository: frepository,
413                commit: fcommit,
414            },
415            ProfileParam::FetchedOrInline {
416                full_id: pfull_id,
417                profile,
418            },
419        ) => {
420            let (function, fcommit) = match function_fetcher
421                .fetch(ctx.clone(), &fowner, &frepository, fcommit.as_deref())
422                .await
423            {
424                Ok(Some(function)) => Ok((function.inner, function.commit)),
425                Ok(_) => Err(super::executions::Error::FunctionNotFound),
426                Err(e) => Err(super::executions::Error::FetchFunction(e)),
427            }?;
428            (
429                Some((fowner, frepository, fcommit)),
430                objectiveai::functions::Function::Remote(function),
431                pfull_id,
432                profile,
433            )
434        }
435        (
436            FunctionParam::FetchedOrInline {
437                full_id: ffull_id,
438                function,
439            },
440            ProfileParam::Remote {
441                owner: powner,
442                repository: prepository,
443                commit: pcommit,
444            },
445        ) => {
446            let (profile, pcommit) = match profile_fetcher
447                .fetch(ctx.clone(), &powner, &prepository, pcommit.as_deref())
448                .await
449            {
450                Ok(Some(profile)) => Ok((profile.inner, profile.commit)),
451                Ok(_) => Err(super::executions::Error::ProfileNotFound),
452                Err(e) => Err(super::executions::Error::FetchProfile(e)),
453            }?;
454            (
455                ffull_id,
456                function,
457                Some((powner, prepository, pcommit)),
458                objectiveai::functions::Profile::Remote(profile),
459            )
460        }
461        (
462            FunctionParam::FetchedOrInline {
463                full_id: ffull_id,
464                function,
465            },
466            ProfileParam::FetchedOrInline {
467                full_id: pfull_id,
468                profile,
469            },
470        ) => (ffull_id, function, pfull_id, profile),
471    };
472
473    // validate input against input_schema
474    if let Some(input_schema) = function.input_schema() {
475        if !input_schema.validate_input(&input) {
476            return Err(super::executions::Error::InputSchemaMismatch);
477        }
478    }
479
480    // validate profile length
481    if match &profile {
482        objectiveai::functions::Profile::Remote(rp) => rp.tasks.len(),
483        objectiveai::functions::Profile::Inline(ip) => ip.tasks.len(),
484    } != function.tasks().len()
485    {
486        return Err(super::executions::Error::InvalidProfile);
487    }
488
489    // take description
490    let description = function.description().map(str::to_owned);
491
492    // take output
493    let output = function.output().clone();
494
495    // take type, compile output_length if needed
496    let r#type = match function {
497        objectiveai::functions::Function::Remote(
498            objectiveai::functions::RemoteFunction::Scalar { .. },
499        ) => FunctionType::Scalar,
500        objectiveai::functions::Function::Remote(
501            objectiveai::functions::RemoteFunction::Vector {
502                ref output_length,
503                ref input_split,
504                ref input_merge,
505                ..
506            },
507        ) => {
508            let params = objectiveai::functions::expression::Params::Ref(
509                objectiveai::functions::expression::ParamsRef {
510                    input: &input,
511                    tasks: &EMPTY_TASKS,
512                    map: None,
513                },
514            );
515            FunctionType::Vector {
516                output_length: Some(
517                    output_length.clone().compile_one(&params)?,
518                ),
519                input_split: Some(input_split.clone()),
520                input_merge: Some(input_merge.clone()),
521            }
522        }
523        objectiveai::functions::Function::Inline(
524            objectiveai::functions::InlineFunction::Scalar { .. },
525        ) => FunctionType::Scalar,
526        objectiveai::functions::Function::Inline(
527            objectiveai::functions::InlineFunction::Vector {
528                ref input_split,
529                ref input_merge,
530                ..
531            },
532        ) => FunctionType::Vector {
533            output_length: None,
534            input_split: input_split.clone(),
535            input_merge: input_merge.clone(),
536        },
537    };
538
539    // compile function tasks
540    let tasks = function.compile_tasks(&input)?;
541
542    // initialize flat tasks / futs vector
543    let mut flat_tasks_or_futs = Vec::with_capacity(tasks.len());
544
545    // iterate through tasks
546    for (i, (task, profile)) in tasks
547        .into_iter()
548        .zip(match profile {
549            objectiveai::functions::Profile::Remote(rp) => {
550                either::Either::Left(rp.tasks.into_iter())
551            }
552            objectiveai::functions::Profile::Inline(ip) => {
553                either::Either::Right(ip.tasks.into_iter())
554            }
555        })
556        .enumerate()
557    {
558        // if skip, push None to flat tasks
559        let task = match task {
560            Some(task) => task,
561            None => {
562                flat_tasks_or_futs.push(TaskFut::SkipTask);
563                continue;
564            }
565        };
566
567        // task path
568        let task_path = {
569            path.push(i as u64);
570            let p = path.clone();
571            path.pop();
572            p
573        };
574
575        // switch by task type
576        match task {
577            objectiveai::functions::CompiledTask::One(
578                objectiveai::functions::Task::ScalarFunction(
579                    objectiveai::functions::ScalarFunctionTask {
580                        owner,
581                        repository,
582                        commit,
583                        input,
584                    },
585                )
586                | objectiveai::functions::Task::VectorFunction(
587                    objectiveai::functions::VectorFunctionTask {
588                        owner,
589                        repository,
590                        commit,
591                        input,
592                    },
593                ),
594            ) => {
595                flat_tasks_or_futs.push(TaskFut::FunctionTaskFut(Box::pin(
596                    get_flat_task_profile(
597                        ctx.clone(),
598                        task_path,
599                        FunctionParam::Remote {
600                            owner,
601                            repository,
602                            commit: Some(commit),
603                        },
604                        match profile {
605                            objectiveai::functions::TaskProfile::RemoteFunction {
606                                owner,
607                                repository,
608                                commit,
609                            } => ProfileParam::Remote {
610                                owner,
611                                repository,
612                                commit,
613                            },
614                            objectiveai::functions::TaskProfile::InlineFunction(
615                                profile,
616                            ) => ProfileParam::FetchedOrInline {
617                                full_id: None,
618                                profile: objectiveai::functions::Profile::Inline(
619                                    profile,
620                                ),
621                            },
622                            _ => return Err(super::executions::Error::InvalidProfile),
623                        },
624                        input,
625                        function_fetcher.clone(),
626                        profile_fetcher.clone(),
627                        ensemble_fetcher.clone(),
628                    )
629                )));
630            }
631            objectiveai::functions::CompiledTask::One(
632                objectiveai::functions::Task::VectorCompletion(task),
633            ) => {
634                let (ensemble, profile) = match profile {
635                    objectiveai::functions::TaskProfile::VectorCompletion {
636                        ensemble,
637                        profile,
638                    } => (ensemble, profile),
639                    _ => return Err(super::executions::Error::InvalidProfile),
640                };
641                flat_tasks_or_futs.push(TaskFut::VectorTaskFut(Box::pin(
642                    get_vector_completion_flat_task_profile(
643                        ctx.clone(),
644                        task_path,
645                        task,
646                        ensemble,
647                        profile,
648                        ensemble_fetcher.clone(),
649                    ),
650                )));
651            }
652            objectiveai::functions::CompiledTask::Many(tasks)
653                if tasks.len() == 0 =>
654            {
655                flat_tasks_or_futs.push(TaskFut::Task(Some(
656                    FlatTaskProfile::MapVectorCompletion(
657                        MapVectorCompletionFlatTaskProfile {
658                            path: task_path,
659                            vector_completions: Vec::new(),
660                        },
661                    ),
662                )));
663            }
664            objectiveai::functions::CompiledTask::Many(tasks) => {
665                let vector_completions = match &tasks[0] {
666                    objectiveai::functions::Task::VectorCompletion(_) => true,
667                    _ => false,
668                };
669                if vector_completions {
670                    let mut futs = Vec::with_capacity(tasks.len());
671                    for (j, task) in tasks.into_iter().enumerate() {
672                        let mut task_path = task_path.clone();
673                        task_path.push(j as u64);
674                        let (ensemble, profile) = match &profile {
675                            objectiveai::functions::TaskProfile::VectorCompletion {
676                                ensemble,
677                                profile,
678                            } => (ensemble.clone(), profile.clone()),
679                            _ => return Err(super::executions::Error::InvalidProfile),
680                        };
681                        futs.push(get_vector_completion_flat_task_profile(
682                            ctx.clone(),
683                            task_path,
684                            match task {
685                                objectiveai::functions::Task::VectorCompletion(
686                                    vc_task,
687                                ) => vc_task,
688                                _ => unreachable!(),
689                            },
690                            ensemble,
691                            profile,
692                            ensemble_fetcher.clone(),
693                        ));
694                    }
695                    flat_tasks_or_futs.push(TaskFut::MapVectorTaskFut((
696                        task_path,
697                        futures::future::try_join_all(futs),
698                    )));
699                } else {
700                    let mut futs = Vec::with_capacity(tasks.len());
701                    for (j, task) in tasks.into_iter().enumerate() {
702                        let mut task_path = task_path.clone();
703                        task_path.push(j as u64);
704                        futs.push(get_flat_task_profile(
705                            ctx.clone(),
706                            task_path,
707                            FunctionParam::Remote {
708                                owner: match &task {
709                                    objectiveai::functions::Task::ScalarFunction(
710                                        sf_task,
711                                    ) => sf_task.owner.clone(),
712                                    objectiveai::functions::Task::VectorFunction(
713                                        vf_task,
714                                    ) => vf_task.owner.clone(),
715                                    _ => unreachable!(),
716                                },
717                                repository: match &task {
718                                    objectiveai::functions::Task::ScalarFunction(
719                                        sf_task,
720                                    ) => sf_task.repository.clone(),
721                                    objectiveai::functions::Task::VectorFunction(
722                                        vf_task,
723                                    ) => vf_task.repository.clone(),
724                                    _ => unreachable!(),
725                                },
726                                commit: Some(match &task {
727                                    objectiveai::functions::Task::ScalarFunction(
728                                        sf_task,
729                                    ) => sf_task.commit.clone(),
730                                    objectiveai::functions::Task::VectorFunction(
731                                        vf_task,
732                                    ) => vf_task.commit.clone(),
733                                    _ => unreachable!(),
734                                }),
735                            },
736                            match &profile {
737                                objectiveai::functions::TaskProfile::RemoteFunction {
738                                    owner,
739                                    repository,
740                                    commit,
741                                } => ProfileParam::Remote {
742                                    owner: owner.clone(),
743                                    repository: repository.clone(),
744                                    commit: commit.clone(),
745                                },
746                                objectiveai::functions::TaskProfile::InlineFunction(
747                                    profile,
748                                ) => ProfileParam::FetchedOrInline {
749                                    full_id: None,
750                                    profile: objectiveai::functions::Profile::Inline(
751                                        profile.clone(),
752                                    ),
753                                },
754                                _ => return Err(super::executions::Error::InvalidProfile),
755                            },
756                            match &task {
757                                objectiveai::functions::Task::ScalarFunction(
758                                    sf_task,
759                                ) => sf_task.input.clone(),
760                                objectiveai::functions::Task::VectorFunction(
761                                    vf_task,
762                                ) => vf_task.input.clone(),
763                                _ => unreachable!(),
764                            },
765                            function_fetcher.clone(),
766                            profile_fetcher.clone(),
767                            ensemble_fetcher.clone(),
768                        ));
769                    }
770                    flat_tasks_or_futs.push(TaskFut::MapFunctionTaskFut((
771                        task_path,
772                        futures::future::try_join_all(futs),
773                    )));
774                }
775            }
776        }
777    }
778
779    // await all futs
780    let tasks = futures::future::try_join_all(flat_tasks_or_futs).await?;
781
782    // return flat function task
783    Ok(super::FunctionFlatTaskProfile {
784        path,
785        description,
786        full_function_id: function_full_id,
787        full_profile_id: profile_full_id,
788        input,
789        tasks,
790        output,
791        r#type,
792    })
793}
794
795async fn get_vector_completion_flat_task_profile<CTXEXT>(
796    ctx: ctx::Context<CTXEXT>,
797    path: Vec<u64>,
798    task: objectiveai::functions::VectorCompletionTask,
799    ensemble: objectiveai::vector::completions::request::Ensemble,
800    profile: Vec<rust_decimal::Decimal>,
801    ensemble_fetcher: Arc<
802        crate::ensemble::fetcher::CachingFetcher<
803            CTXEXT,
804            impl crate::ensemble::fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
805        >,
806    >,
807) -> Result<super::VectorCompletionFlatTaskProfile, super::executions::Error>
808where
809    CTXEXT: Send + Sync + 'static,
810{
811    // switch based on profile
812    let ensemble = match ensemble {
813        objectiveai::vector::completions::request::Ensemble::Id(id) => {
814            // fetch ensemble
815            ensemble_fetcher
816                .fetch(ctx, &id)
817                .map(|result| match result {
818                    Ok(Some((ensemble, _))) => Ok(ensemble),
819                    Ok(None) => Err(super::executions::Error::EnsembleNotFound),
820                    Err(e) => Err(super::executions::Error::FetchEnsemble(e)),
821                })
822                .await?
823        }
824        objectiveai::vector::completions::request::Ensemble::Provided(
825            ensemble,
826        ) => {
827            // validate ensemble
828            ensemble
829                .clone()
830                .try_into()
831                .map_err(super::executions::Error::InvalidEnsemble)?
832        }
833    };
834
835    // validate profile length
836    if profile.len() != ensemble.llms.len() {
837        return Err(super::executions::Error::InvalidProfile);
838    }
839
840    // construct flat task profile
841    Ok(super::VectorCompletionFlatTaskProfile {
842        path,
843        ensemble: objectiveai::ensemble::EnsembleBase {
844            llms: ensemble
845                .llms
846                .into_iter()
847                .map(|llm| {
848                    objectiveai::ensemble_llm::EnsembleLlmBaseWithFallbacksAndCount {
849                        count: llm.count,
850                        inner: llm.inner.base,
851                        fallbacks: llm.fallbacks.map(|fallbacks| {
852                            fallbacks
853                                .into_iter()
854                                .map(|fallback| fallback.base)
855                                .collect()
856                        }),
857                    }
858                })
859                .collect(),
860        },
861        profile,
862        messages: task.messages,
863        tools: task.tools,
864        responses: task.responses,
865    })
866}
867
868enum TaskFut<
869    VFUT: Future<
870        Output = Result<
871            super::VectorCompletionFlatTaskProfile,
872            super::executions::Error,
873        >,
874    >,
875    FFUT: Future<
876        Output = Result<
877            super::FunctionFlatTaskProfile,
878            super::executions::Error,
879        >,
880    >,
881> {
882    SkipTask,
883    Task(Option<super::FlatTaskProfile>),
884    VectorTaskFut(Pin<Box<VFUT>>),
885    MapVectorTaskFut((Vec<u64>, futures::future::TryJoinAll<VFUT>)),
886    FunctionTaskFut(Pin<Box<FFUT>>),
887    MapFunctionTaskFut((Vec<u64>, futures::future::TryJoinAll<FFUT>)),
888}
889
890impl<VFUT, FFUT> Future for TaskFut<VFUT, FFUT>
891where
892    VFUT: Future<
893        Output = Result<
894            super::VectorCompletionFlatTaskProfile,
895            super::executions::Error,
896        >,
897    >,
898    FFUT: Future<
899        Output = Result<
900            super::FunctionFlatTaskProfile,
901            super::executions::Error,
902        >,
903    >,
904{
905    type Output =
906        Result<Option<super::FlatTaskProfile>, super::executions::Error>;
907    fn poll(
908        self: Pin<&mut Self>,
909        cx: &mut std::task::Context<'_>,
910    ) -> Poll<Self::Output> {
911        match self.get_mut() {
912            TaskFut::SkipTask => Poll::Ready(Ok(None)),
913            TaskFut::Task(task) => Poll::Ready(Ok(task.take())),
914            TaskFut::VectorTaskFut(fut) => Pin::new(fut)
915                .poll(cx)
916                .map_ok(FlatTaskProfile::VectorCompletion)
917                .map_ok(Some),
918            TaskFut::MapVectorTaskFut((path, futs)) => {
919                Pin::new(futs).poll(cx).map_ok(|results| {
920                    Some(FlatTaskProfile::MapVectorCompletion(
921                        MapVectorCompletionFlatTaskProfile {
922                            path: path.clone(),
923                            vector_completions: results,
924                        },
925                    ))
926                })
927            }
928            TaskFut::FunctionTaskFut(fut) => Pin::new(fut)
929                .poll(cx)
930                .map_ok(FlatTaskProfile::Function)
931                .map_ok(Some),
932            TaskFut::MapFunctionTaskFut((path, futs)) => {
933                Pin::new(futs).poll(cx).map_ok(|results| {
934                    Some(FlatTaskProfile::MapFunction(
935                        MapFunctionFlatTaskProfile {
936                            path: path.clone(),
937                            functions: results,
938                        },
939                    ))
940                })
941            }
942        }
943    }
944}