objectiveai_api/functions/
flat_task_profile.rs

1//! Flattened execution data for function execution.
2//!
3//! Transforms nested Function and Profile definitions into a flat structure suitable
4//! for parallel execution. A Function defines the task structure and expressions,
5//! while a Profile provides the weights for each task. This module combines both
6//! into flattened executable tasks.
7
8use crate::ctx;
9use futures::FutureExt;
10use std::{
11    pin::Pin,
12    sync::{Arc, LazyLock},
13    task::Poll,
14};
15
16/// A flattened task ready for execution.
17///
18/// Combines Function structure with Profile weights into an executable node.
19/// Can be a function (with nested tasks), a mapped array of functions, a vector
20/// completion, or a mapped array of vector completions.
21#[derive(Debug, Clone)]
22pub enum FlatTaskProfile {
23    /// A single function task with nested tasks.
24    Function(FunctionFlatTaskProfile),
25    /// Multiple function tasks from a mapped expression.
26    MapFunction(MapFunctionFlatTaskProfile),
27    /// A single vector completion task.
28    VectorCompletion(VectorCompletionFlatTaskProfile),
29    /// Multiple vector completion tasks from a mapped expression.
30    MapVectorCompletion(MapVectorCompletionFlatTaskProfile),
31}
32
33impl FlatTaskProfile {
34    /// Returns an iterator over all vector completion tasks.
35    ///
36    /// Recursively traverses function tasks to collect all leaf vector completions.
37    pub fn vector_completion_ftps(
38        &self,
39    ) -> impl Iterator<Item = &VectorCompletionFlatTaskProfile> {
40        enum Iter<'a> {
41            Function(
42                Box<
43                    dyn Iterator<Item = &'a VectorCompletionFlatTaskProfile>
44                        + 'a,
45                >,
46            ),
47            MapFunction(
48                Box<
49                    dyn Iterator<Item = &'a VectorCompletionFlatTaskProfile>
50                        + 'a,
51                >,
52            ),
53            VectorCompletion(Option<&'a VectorCompletionFlatTaskProfile>),
54            MapVectorCompletion(
55                std::slice::Iter<'a, VectorCompletionFlatTaskProfile>,
56            ),
57        }
58        impl<'a> Iterator for Iter<'a> {
59            type Item = &'a VectorCompletionFlatTaskProfile;
60            fn next(&mut self) -> Option<Self::Item> {
61                match self {
62                    Iter::Function(iter) => iter.next(),
63                    Iter::MapFunction(iter) => iter.next(),
64                    Iter::VectorCompletion(opt) => opt.take(),
65                    Iter::MapVectorCompletion(iter) => iter.next(),
66                }
67            }
68        }
69        match self {
70            FlatTaskProfile::Function(function) => Iter::Function(Box::new(
71                function
72                    .tasks
73                    .iter()
74                    .filter_map(|task| task.as_ref())
75                    .flat_map(|task| task.vector_completion_ftps()),
76            )),
77            FlatTaskProfile::MapFunction(functions) => {
78                Iter::MapFunction(Box::new(
79                    functions
80                        .functions
81                        .iter()
82                        .flat_map(|function| function.tasks.iter())
83                        .filter_map(|task| task.as_ref())
84                        .flat_map(|task| task.vector_completion_ftps()),
85                ))
86            }
87            FlatTaskProfile::VectorCompletion(vector) => {
88                Iter::VectorCompletion(Some(vector))
89            }
90            FlatTaskProfile::MapVectorCompletion(vectors) => {
91                Iter::MapVectorCompletion(vectors.vector_completions.iter())
92            }
93        }
94    }
95    /// Returns the total number of leaf tasks (vector completions).
96    pub fn len(&self) -> usize {
97        match self {
98            FlatTaskProfile::Function(function) => function.len(),
99            FlatTaskProfile::MapFunction(functions) => functions.len(),
100            FlatTaskProfile::VectorCompletion(vector) => vector.len(),
101            FlatTaskProfile::MapVectorCompletion(vectors) => vectors.len(),
102        }
103    }
104
105    /// Returns the number of task indices needed for output assembly.
106    pub fn task_index_len(&self) -> usize {
107        match self {
108            FlatTaskProfile::Function(function) => function.task_index_len(),
109            FlatTaskProfile::MapFunction(functions) => {
110                functions.task_index_len()
111            }
112            FlatTaskProfile::VectorCompletion(vector) => {
113                vector.task_index_len()
114            }
115            FlatTaskProfile::MapVectorCompletion(vectors) => {
116                vectors.task_index_len()
117            }
118        }
119    }
120}
121
122/// Multiple function tasks from a mapped expression.
123///
124/// Created when a task has a `map` index pointing to an input_maps sub-array.
125/// Each element in that array produces one function instance.
126#[derive(Debug, Clone)]
127pub struct MapFunctionFlatTaskProfile {
128    /// Path to this task in the Function tree (indices into tasks arrays).
129    pub path: Vec<u64>,
130    /// The individual flattened function tasks, one per element in the mapped array.
131    pub functions: Vec<FunctionFlatTaskProfile>,
132}
133
134impl MapFunctionFlatTaskProfile {
135    pub fn len(&self) -> usize {
136        self.functions
137            .iter()
138            .map(FunctionFlatTaskProfile::len)
139            .sum()
140    }
141
142    pub fn task_index_len(&self) -> usize {
143        self.functions
144            .iter()
145            .map(FunctionFlatTaskProfile::task_index_len)
146            .sum::<usize>()
147            .max(1)
148    }
149}
150
151/// A flattened function task ready for execution.
152///
153/// Combines a Function definition with its corresponding Profile. Contains the
154/// compiled input, nested tasks with their weights, and the output expression.
155#[derive(Debug, Clone)]
156pub struct FunctionFlatTaskProfile {
157    /// Path to this task in the Function tree (indices into tasks arrays).
158    pub path: Vec<u64>,
159    /// Full Function ID (owner, repository, commit) if from GitHub.
160    pub full_function_id: Option<(String, String, String)>,
161    /// Full Profile ID (owner, repository, commit) if from GitHub.
162    pub full_profile_id: Option<(String, String, String)>,
163    /// Description from the Function definition.
164    pub description: Option<String>,
165    /// The compiled input for this Function.
166    pub input: objectiveai::functions::expression::Input,
167    /// The flattened child tasks (None if task was skipped).
168    pub tasks: Vec<Option<FlatTaskProfile>>,
169    /// The output expression for computing the final score.
170    pub output: objectiveai::functions::expression::Expression,
171    /// The Function type (scalar or vector).
172    pub r#type: FunctionType,
173}
174
175impl FunctionFlatTaskProfile {
176    pub fn len(&self) -> usize {
177        self.tasks
178            .iter()
179            .map(|task| task.as_ref().map_or(1, |task| task.len()))
180            .sum()
181    }
182
183    pub fn task_index_len(&self) -> usize {
184        let mut len = 0;
185        for task in &self.tasks {
186            len += if let Some(task) = task {
187                task.task_index_len()
188            } else {
189                1
190            };
191        }
192        len
193    }
194
195    pub fn task_indices(&self) -> Vec<u64> {
196        let mut indices = Vec::with_capacity(self.tasks.len());
197        let mut current_index = 0u64;
198        for task in &self.tasks {
199            indices.push(current_index);
200            current_index += if let Some(task) = task {
201                task.task_index_len()
202            } else {
203                1
204            } as u64;
205        }
206        indices
207    }
208}
209
/// The type of a Function's output.
///
/// Plain `Copy` data, so it also supports equality and hashing. Two `Vector`
/// values are equal only when their `output_length` values match (including
/// both being `None`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FunctionType {
    /// Produces a single score in [0, 1].
    Scalar,
    /// Produces a vector of scores that sums to ~1.
    Vector {
        /// Expected output length, if known from the output_length
        /// expression.
        output_length: Option<u64>,
    },
}
221
222/// Multiple vector completion tasks from a mapped expression.
223///
224/// Created when a vector completion task has a `map` index. Each element in the
225/// mapped array produces one vector completion instance.
226#[derive(Debug, Clone)]
227pub struct MapVectorCompletionFlatTaskProfile {
228    /// Path to this task in the Function tree (indices into tasks arrays).
229    pub path: Vec<u64>,
230    /// The individual flattened vector completion tasks.
231    pub vector_completions: Vec<VectorCompletionFlatTaskProfile>,
232}
233
234impl MapVectorCompletionFlatTaskProfile {
235    pub fn len(&self) -> usize {
236        self.vector_completions.len()
237    }
238
239    pub fn task_index_len(&self) -> usize {
240        self.vector_completions.len().max(1)
241    }
242}
243
244/// A flattened vector completion task ready for execution.
245///
246/// The leaf task type. Contains everything needed to run a vector completion:
247/// the Ensemble of LLMs, their weights from the Profile, and the compiled
248/// messages/responses.
249#[derive(Debug, Clone)]
250pub struct VectorCompletionFlatTaskProfile {
251    /// Path to this task in the Function tree (indices into tasks arrays).
252    pub path: Vec<u64>,
253    /// The Ensemble configuration with LLMs and their settings.
254    pub ensemble: objectiveai::ensemble::EnsembleBase,
255    /// The weights for each LLM in the Ensemble (from the Profile).
256    pub profile: Vec<rust_decimal::Decimal>,
257    /// The compiled messages for the vector completion.
258    pub messages: Vec<objectiveai::chat::completions::request::Message>,
259    /// Optional tools for the vector completion (read-only context).
260    pub tools: Option<Vec<objectiveai::chat::completions::request::Tool>>,
261    /// The compiled response options the LLMs will vote on.
262    pub responses: Vec<objectiveai::chat::completions::request::RichContent>,
263}
264
265impl VectorCompletionFlatTaskProfile {
266    pub fn len(&self) -> usize {
267        1
268    }
269
270    pub fn task_index_len(&self) -> usize {
271        1
272    }
273}
274
275/// Parameter for specifying a function source.
276#[derive(Debug, Clone)]
277pub enum FunctionParam {
278    /// Function to fetch from GitHub by owner/repository/commit.
279    Remote {
280        owner: String,
281        repository: String,
282        commit: Option<String>,
283    },
284    /// Already-fetched or inline function definition.
285    FetchedOrInline {
286        full_id: Option<(String, String, String)>,
287        function: objectiveai::functions::Function,
288    },
289}
290
291/// Parameter for specifying a profile source.
292#[derive(Debug, Clone)]
293pub enum ProfileParam {
294    /// Profile to fetch from GitHub by owner/repository/commit.
295    Remote {
296        owner: String,
297        repository: String,
298        commit: Option<String>,
299    },
300    /// Already-fetched or inline profile definition.
301    FetchedOrInline {
302        full_id: Option<(String, String, String)>,
303        profile: objectiveai::functions::Profile,
304    },
305}
306
307/// Recursively builds a flattened task from a Function and Profile.
308///
309/// Fetches any remote Functions/Profiles/Ensembles, compiles task expressions
310/// with the input, and validates that the Profile structure matches the Function.
311/// The result is a flat tree of tasks ready for parallel execution.
312pub async fn get_flat_task_profile<CTXEXT>(
313    ctx: ctx::Context<CTXEXT>,
314    mut path: Vec<u64>,
315    function: FunctionParam,
316    profile: ProfileParam,
317    input: objectiveai::functions::expression::Input,
318    function_fetcher: Arc<
319        impl super::function_fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
320    >,
321    profile_fetcher: Arc<
322        impl super::profile_fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
323    >,
324    ensemble_fetcher: Arc<
325        crate::ensemble::fetcher::CachingFetcher<
326            CTXEXT,
327            impl crate::ensemble::fetcher::Fetcher<CTXEXT>
328            + Send
329            + Sync
330            + 'static,
331        >,
332    >,
333) -> Result<super::FunctionFlatTaskProfile, super::executions::Error>
334where
335    CTXEXT: Send + Sync + 'static,
336{
337    static EMPTY_TASKS: LazyLock<
338        Vec<Option<objectiveai::functions::expression::TaskOutput>>,
339    > = LazyLock::new(|| Vec::new());
340
341    // fetch function and profile if needed
342    let (function_full_id, function, profile_full_id, profile): (
343        Option<(String, String, String)>,
344        objectiveai::functions::Function,
345        Option<(String, String, String)>,
346        objectiveai::functions::Profile,
347    ) = match (function, profile) {
348        (
349            FunctionParam::Remote {
350                owner: fowner,
351                repository: frepository,
352                commit: fcommit,
353            },
354            ProfileParam::Remote {
355                owner: powner,
356                repository: prepository,
357                commit: pcommit,
358            },
359        ) => {
360            let ((function, fcommit), (profile, pcommit)) = tokio::try_join!(
361                function_fetcher
362                    .fetch(
363                        ctx.clone(),
364                        &fowner,
365                        &frepository,
366                        fcommit.as_deref()
367                    )
368                    .map(|result| match result {
369                        Ok(Some(function)) => {
370                            Ok((function.inner, function.commit))
371                        }
372                        Ok(_) =>
373                            Err(super::executions::Error::FunctionNotFound),
374                        Err(e) =>
375                            Err(super::executions::Error::FetchFunction(e)),
376                    }),
377                profile_fetcher
378                    .fetch(
379                        ctx.clone(),
380                        &powner,
381                        &prepository,
382                        pcommit.as_deref(),
383                    )
384                    .map(|result| match result {
385                        Ok(Some(profile)) => {
386                            Ok((profile.inner, profile.commit))
387                        }
388                        Ok(_) => Err(super::executions::Error::ProfileNotFound),
389                        Err(e) =>
390                            Err(super::executions::Error::FetchProfile(e)),
391                    }),
392            )?;
393            (
394                Some((fowner.to_owned(), frepository.to_owned(), fcommit)),
395                objectiveai::functions::Function::Remote(function),
396                Some((powner, prepository, pcommit)),
397                objectiveai::functions::Profile::Remote(profile),
398            )
399        }
400        (
401            FunctionParam::Remote {
402                owner: fowner,
403                repository: frepository,
404                commit: fcommit,
405            },
406            ProfileParam::FetchedOrInline {
407                full_id: pfull_id,
408                profile,
409            },
410        ) => {
411            let (function, fcommit) = match function_fetcher
412                .fetch(ctx.clone(), &fowner, &frepository, fcommit.as_deref())
413                .await
414            {
415                Ok(Some(function)) => Ok((function.inner, function.commit)),
416                Ok(_) => Err(super::executions::Error::FunctionNotFound),
417                Err(e) => Err(super::executions::Error::FetchFunction(e)),
418            }?;
419            (
420                Some((fowner, frepository, fcommit)),
421                objectiveai::functions::Function::Remote(function),
422                pfull_id,
423                profile,
424            )
425        }
426        (
427            FunctionParam::FetchedOrInline {
428                full_id: ffull_id,
429                function,
430            },
431            ProfileParam::Remote {
432                owner: powner,
433                repository: prepository,
434                commit: pcommit,
435            },
436        ) => {
437            let (profile, pcommit) = match profile_fetcher
438                .fetch(ctx.clone(), &powner, &prepository, pcommit.as_deref())
439                .await
440            {
441                Ok(Some(profile)) => Ok((profile.inner, profile.commit)),
442                Ok(_) => Err(super::executions::Error::ProfileNotFound),
443                Err(e) => Err(super::executions::Error::FetchProfile(e)),
444            }?;
445            (
446                ffull_id,
447                function,
448                Some((powner, prepository, pcommit)),
449                objectiveai::functions::Profile::Remote(profile),
450            )
451        }
452        (
453            FunctionParam::FetchedOrInline {
454                full_id: ffull_id,
455                function,
456            },
457            ProfileParam::FetchedOrInline {
458                full_id: pfull_id,
459                profile,
460            },
461        ) => (ffull_id, function, pfull_id, profile),
462    };
463
464    // validate input against input_schema
465    if let Some(input_schema) = function.input_schema() {
466        if !input_schema.validate_input(&input) {
467            return Err(super::executions::Error::InputSchemaMismatch);
468        }
469    }
470
471    // validate profile length
472    if match &profile {
473        objectiveai::functions::Profile::Remote(rp) => rp.tasks.len(),
474        objectiveai::functions::Profile::Inline(ip) => ip.tasks.len(),
475    } != function.tasks().len()
476    {
477        return Err(super::executions::Error::InvalidProfile);
478    }
479
480    // take description
481    let description = function.description().map(str::to_owned);
482
483    // take output
484    let output = function.output().clone();
485
486    // take type, compile output_length if needed
487    let r#type = match function {
488        objectiveai::functions::Function::Remote(
489            objectiveai::functions::RemoteFunction::Scalar { .. },
490        ) => FunctionType::Scalar,
491        objectiveai::functions::Function::Remote(
492            objectiveai::functions::RemoteFunction::Vector {
493                ref output_length,
494                ..
495            },
496        ) => {
497            let params = objectiveai::functions::expression::Params::Ref(
498                objectiveai::functions::expression::ParamsRef {
499                    input: &input,
500                    tasks: &EMPTY_TASKS,
501                    map: None,
502                },
503            );
504            FunctionType::Vector {
505                output_length: Some(
506                    output_length.clone().compile_one(&params)?,
507                ),
508            }
509        }
510        objectiveai::functions::Function::Inline(
511            objectiveai::functions::InlineFunction::Scalar { .. },
512        ) => FunctionType::Scalar,
513        objectiveai::functions::Function::Inline(
514            objectiveai::functions::InlineFunction::Vector { .. },
515        ) => FunctionType::Vector {
516            output_length: None,
517        },
518    };
519
520    // compile function tasks
521    let tasks = function.compile_tasks(&input)?;
522
523    // initialize flat tasks / futs vector
524    let mut flat_tasks_or_futs = Vec::with_capacity(tasks.len());
525
526    // iterate through tasks
527    for (i, (task, profile)) in tasks
528        .into_iter()
529        .zip(match profile {
530            objectiveai::functions::Profile::Remote(rp) => {
531                either::Either::Left(rp.tasks.into_iter())
532            }
533            objectiveai::functions::Profile::Inline(ip) => {
534                either::Either::Right(ip.tasks.into_iter())
535            }
536        })
537        .enumerate()
538    {
539        // if skip, push None to flat tasks
540        let task = match task {
541            Some(task) => task,
542            None => {
543                flat_tasks_or_futs.push(TaskFut::SkipTask);
544                continue;
545            }
546        };
547
548        // task path
549        let task_path = {
550            path.push(i as u64);
551            let p = path.clone();
552            path.pop();
553            p
554        };
555
556        // switch by task type
557        match task {
558            objectiveai::functions::CompiledTask::One(
559                objectiveai::functions::Task::ScalarFunction(
560                    objectiveai::functions::ScalarFunctionTask {
561                        owner,
562                        repository,
563                        commit,
564                        input,
565                    },
566                )
567                | objectiveai::functions::Task::VectorFunction(
568                    objectiveai::functions::VectorFunctionTask {
569                        owner,
570                        repository,
571                        commit,
572                        input,
573                    },
574                ),
575            ) => {
576                flat_tasks_or_futs.push(TaskFut::FunctionTaskFut(Box::pin(
577                    get_flat_task_profile(
578                        ctx.clone(),
579                        task_path,
580                        FunctionParam::Remote {
581                            owner,
582                            repository,
583                            commit: Some(commit),
584                        },
585                        match profile {
586                            objectiveai::functions::TaskProfile::RemoteFunction {
587                                owner,
588                                repository,
589                                commit,
590                            } => ProfileParam::Remote {
591                                owner,
592                                repository,
593                                commit,
594                            },
595                            objectiveai::functions::TaskProfile::InlineFunction(
596                                profile,
597                            ) => ProfileParam::FetchedOrInline {
598                                full_id: None,
599                                profile: objectiveai::functions::Profile::Inline(
600                                    profile,
601                                ),
602                            },
603                            _ => return Err(super::executions::Error::InvalidProfile),
604                        },
605                        input,
606                        function_fetcher.clone(),
607                        profile_fetcher.clone(),
608                        ensemble_fetcher.clone(),
609                    )
610                )));
611            }
612            objectiveai::functions::CompiledTask::One(
613                objectiveai::functions::Task::VectorCompletion(task),
614            ) => {
615                let (ensemble, profile) = match profile {
616                    objectiveai::functions::TaskProfile::VectorCompletion {
617                        ensemble,
618                        profile,
619                    } => (ensemble, profile),
620                    _ => return Err(super::executions::Error::InvalidProfile),
621                };
622                flat_tasks_or_futs.push(TaskFut::VectorTaskFut(Box::pin(
623                    get_vector_completion_flat_task_profile(
624                        ctx.clone(),
625                        task_path,
626                        task,
627                        ensemble,
628                        profile,
629                        ensemble_fetcher.clone(),
630                    ),
631                )));
632            }
633            objectiveai::functions::CompiledTask::Many(tasks)
634                if tasks.len() == 0 =>
635            {
636                flat_tasks_or_futs.push(TaskFut::Task(Some(
637                    FlatTaskProfile::MapVectorCompletion(
638                        MapVectorCompletionFlatTaskProfile {
639                            path: task_path,
640                            vector_completions: Vec::new(),
641                        },
642                    ),
643                )));
644            }
645            objectiveai::functions::CompiledTask::Many(tasks) => {
646                let vector_completions = match &tasks[0] {
647                    objectiveai::functions::Task::VectorCompletion(_) => true,
648                    _ => false,
649                };
650                if vector_completions {
651                    let mut futs = Vec::with_capacity(tasks.len());
652                    for (j, task) in tasks.into_iter().enumerate() {
653                        let mut task_path = task_path.clone();
654                        task_path.push(j as u64);
655                        let (ensemble, profile) = match &profile {
656                            objectiveai::functions::TaskProfile::VectorCompletion {
657                                ensemble,
658                                profile,
659                            } => (ensemble.clone(), profile.clone()),
660                            _ => return Err(super::executions::Error::InvalidProfile),
661                        };
662                        futs.push(get_vector_completion_flat_task_profile(
663                            ctx.clone(),
664                            task_path,
665                            match task {
666                                objectiveai::functions::Task::VectorCompletion(
667                                    vc_task,
668                                ) => vc_task,
669                                _ => unreachable!(),
670                            },
671                            ensemble,
672                            profile,
673                            ensemble_fetcher.clone(),
674                        ));
675                    }
676                    flat_tasks_or_futs.push(TaskFut::MapVectorTaskFut((
677                        task_path,
678                        futures::future::try_join_all(futs),
679                    )));
680                } else {
681                    let mut futs = Vec::with_capacity(tasks.len());
682                    for (j, task) in tasks.into_iter().enumerate() {
683                        let mut task_path = task_path.clone();
684                        task_path.push(j as u64);
685                        futs.push(get_flat_task_profile(
686                            ctx.clone(),
687                            task_path,
688                            FunctionParam::Remote {
689                                owner: match &task {
690                                    objectiveai::functions::Task::ScalarFunction(
691                                        sf_task,
692                                    ) => sf_task.owner.clone(),
693                                    objectiveai::functions::Task::VectorFunction(
694                                        vf_task,
695                                    ) => vf_task.owner.clone(),
696                                    _ => unreachable!(),
697                                },
698                                repository: match &task {
699                                    objectiveai::functions::Task::ScalarFunction(
700                                        sf_task,
701                                    ) => sf_task.repository.clone(),
702                                    objectiveai::functions::Task::VectorFunction(
703                                        vf_task,
704                                    ) => vf_task.repository.clone(),
705                                    _ => unreachable!(),
706                                },
707                                commit: Some(match &task {
708                                    objectiveai::functions::Task::ScalarFunction(
709                                        sf_task,
710                                    ) => sf_task.commit.clone(),
711                                    objectiveai::functions::Task::VectorFunction(
712                                        vf_task,
713                                    ) => vf_task.commit.clone(),
714                                    _ => unreachable!(),
715                                }),
716                            },
717                            match &profile {
718                                objectiveai::functions::TaskProfile::RemoteFunction {
719                                    owner,
720                                    repository,
721                                    commit,
722                                } => ProfileParam::Remote {
723                                    owner: owner.clone(),
724                                    repository: repository.clone(),
725                                    commit: commit.clone(),
726                                },
727                                objectiveai::functions::TaskProfile::InlineFunction(
728                                    profile,
729                                ) => ProfileParam::FetchedOrInline {
730                                    full_id: None,
731                                    profile: objectiveai::functions::Profile::Inline(
732                                        profile.clone(),
733                                    ),
734                                },
735                                _ => return Err(super::executions::Error::InvalidProfile),
736                            },
737                            match &task {
738                                objectiveai::functions::Task::ScalarFunction(
739                                    sf_task,
740                                ) => sf_task.input.clone(),
741                                objectiveai::functions::Task::VectorFunction(
742                                    vf_task,
743                                ) => vf_task.input.clone(),
744                                _ => unreachable!(),
745                            },
746                            function_fetcher.clone(),
747                            profile_fetcher.clone(),
748                            ensemble_fetcher.clone(),
749                        ));
750                    }
751                    flat_tasks_or_futs.push(TaskFut::MapFunctionTaskFut((
752                        task_path,
753                        futures::future::try_join_all(futs),
754                    )));
755                }
756            }
757        }
758    }
759
760    // await all futs
761    let tasks = futures::future::try_join_all(flat_tasks_or_futs).await?;
762
763    // return flat function task
764    Ok(super::FunctionFlatTaskProfile {
765        path,
766        description,
767        full_function_id: function_full_id,
768        full_profile_id: profile_full_id,
769        input,
770        tasks,
771        output,
772        r#type,
773    })
774}
775
776async fn get_vector_completion_flat_task_profile<CTXEXT>(
777    ctx: ctx::Context<CTXEXT>,
778    path: Vec<u64>,
779    task: objectiveai::functions::VectorCompletionTask,
780    ensemble: objectiveai::vector::completions::request::Ensemble,
781    profile: Vec<rust_decimal::Decimal>,
782    ensemble_fetcher: Arc<
783        crate::ensemble::fetcher::CachingFetcher<
784            CTXEXT,
785            impl crate::ensemble::fetcher::Fetcher<CTXEXT>
786            + Send
787            + Sync
788            + 'static,
789        >,
790    >,
791) -> Result<super::VectorCompletionFlatTaskProfile, super::executions::Error>
792where
793    CTXEXT: Send + Sync + 'static,
794{
795    // switch based on profile
796    let ensemble = match ensemble {
797        objectiveai::vector::completions::request::Ensemble::Id(id) => {
798            // fetch ensemble
799            ensemble_fetcher
800                .fetch(ctx, &id)
801                .map(|result| match result {
802                    Ok(Some((ensemble, _))) => Ok(ensemble),
803                    Ok(None) => Err(super::executions::Error::EnsembleNotFound),
804                    Err(e) => Err(super::executions::Error::FetchEnsemble(e)),
805                })
806                .await?
807        }
808        objectiveai::vector::completions::request::Ensemble::Provided(
809            ensemble,
810        ) => {
811            // validate ensemble
812            ensemble
813                .clone()
814                .try_into()
815                .map_err(super::executions::Error::InvalidEnsemble)?
816        }
817    };
818
819    // validate profile length
820    if profile.len() != ensemble.llms.len() {
821        return Err(super::executions::Error::InvalidProfile);
822    }
823
824    // construct flat task profile
825    Ok(super::VectorCompletionFlatTaskProfile {
826        path,
827        ensemble: objectiveai::ensemble::EnsembleBase {
828            llms: ensemble
829                .llms
830                .into_iter()
831                .map(|llm| {
832                    objectiveai::ensemble_llm::EnsembleLlmBaseWithFallbacksAndCount {
833                        count: llm.count,
834                        inner: llm.inner.base,
835                        fallbacks: llm.fallbacks.map(|fallbacks| {
836                            fallbacks
837                                .into_iter()
838                                .map(|fallback| fallback.base)
839                                .collect()
840                        }),
841                    }
842                })
843                .collect(),
844        },
845        profile,
846        messages: task.messages,
847        tools: task.tools,
848        responses: task.responses,
849    })
850}
851
852enum TaskFut<
853    VFUT: Future<
854        Output = Result<
855            super::VectorCompletionFlatTaskProfile,
856            super::executions::Error,
857        >,
858    >,
859    FFUT: Future<
860        Output = Result<
861            super::FunctionFlatTaskProfile,
862            super::executions::Error,
863        >,
864    >,
865> {
866    SkipTask,
867    Task(Option<super::FlatTaskProfile>),
868    VectorTaskFut(Pin<Box<VFUT>>),
869    MapVectorTaskFut((Vec<u64>, futures::future::TryJoinAll<VFUT>)),
870    FunctionTaskFut(Pin<Box<FFUT>>),
871    MapFunctionTaskFut((Vec<u64>, futures::future::TryJoinAll<FFUT>)),
872}
873
874impl<VFUT, FFUT> Future for TaskFut<VFUT, FFUT>
875where
876    VFUT: Future<
877        Output = Result<
878            super::VectorCompletionFlatTaskProfile,
879            super::executions::Error,
880        >,
881    >,
882    FFUT: Future<
883        Output = Result<
884            super::FunctionFlatTaskProfile,
885            super::executions::Error,
886        >,
887    >,
888{
889    type Output =
890        Result<Option<super::FlatTaskProfile>, super::executions::Error>;
891    fn poll(
892        self: Pin<&mut Self>,
893        cx: &mut std::task::Context<'_>,
894    ) -> Poll<Self::Output> {
895        match self.get_mut() {
896            TaskFut::SkipTask => Poll::Ready(Ok(None)),
897            TaskFut::Task(task) => Poll::Ready(Ok(task.take())),
898            TaskFut::VectorTaskFut(fut) => Pin::new(fut)
899                .poll(cx)
900                .map_ok(FlatTaskProfile::VectorCompletion)
901                .map_ok(Some),
902            TaskFut::MapVectorTaskFut((path, futs)) => {
903                Pin::new(futs).poll(cx).map_ok(|results| {
904                    Some(FlatTaskProfile::MapVectorCompletion(
905                        MapVectorCompletionFlatTaskProfile {
906                            path: path.clone(),
907                            vector_completions: results,
908                        },
909                    ))
910                })
911            }
912            TaskFut::FunctionTaskFut(fut) => Pin::new(fut)
913                .poll(cx)
914                .map_ok(FlatTaskProfile::Function)
915                .map_ok(Some),
916            TaskFut::MapFunctionTaskFut((path, futs)) => {
917                Pin::new(futs).poll(cx).map_ok(|results| {
918                    Some(FlatTaskProfile::MapFunction(
919                        MapFunctionFlatTaskProfile {
920                            path: path.clone(),
921                            functions: results,
922                        },
923                    ))
924                })
925            }
926        }
927    }
928}