crate::ix!();
/// Shared handle to a dynamically dispatched [`LanguageModelClientInterface`]
/// instantiated with [`LanguageModelBatchWorkflowError`] as its type argument.
///
/// NOTE(review): `Arc` is presumably `std::sync::Arc`, brought into scope by
/// `crate::ix!()` — confirm.
pub type LanguageModelClientArc = Arc<dyn LanguageModelClientInterface<LanguageModelBatchWorkflowError>>;
/// Supplies a system message as an owned `String`.
///
/// The function takes no receiver, so it is invoked on the implementing type
/// itself (`T::system_message()`) rather than on a value of that type.
pub trait ComputeSystemMessage {
    /// Returns the system message associated with the implementing type.
    fn system_message() -> String;
}
/// Builds a query string from a single seed value.
pub trait ComputeLanguageModelCoreQuery {
    /// Input the query is computed from. It must implement both
    /// `HasAssociatedOutputName` and `Named`.
    type Seed: HasAssociatedOutputName + Named;

    /// Returns the query text derived from `input`.
    ///
    /// How the seed is turned into text is entirely up to the implementor.
    fn compute_language_model_core_query(&self, input: &Self::Seed) -> String;
}
/// Asynchronous hook for finishing batches that were left uncompleted.
///
/// What counts as "uncompleted", and where such batches are discovered, is
/// decided by the implementor; this trait only fixes the call shape.
#[async_trait]
pub trait FinishProcessingUncompletedBatches {
    /// Error reported when finishing the outstanding batches fails.
    type Error;

    /// Finishes any outstanding batches.
    ///
    /// `expected_content_type` is forwarded by reference to the implementor.
    /// Returns `Ok(())` on success or `Self::Error` on failure.
    async fn finish_processing_uncompleted_batches(
        &self,
        expected_content_type: &ExpectedContentType,
    ) -> Result<(), Self::Error>;
}
/// Turns a slice of seeds into batch API requests for a given model.
pub trait ComputeLanguageModelRequests {
    /// Input item type. It must implement `HasAssociatedOutputName` and be
    /// `Send + Sync`.
    type Seed: HasAssociatedOutputName + Send + Sync;

    /// Produces the requests for `input_tokens` against `model`.
    ///
    /// Despite its name, `input_tokens` is a slice of `Self::Seed` values.
    /// The number and ordering of the returned requests relative to the seeds
    /// is not constrained by this trait.
    fn compute_language_model_requests(
        &self,
        model: &LanguageModelType,
        input_tokens: &[Self::Seed],
    ) -> Vec<LanguageModelBatchAPIRequest>;
}
/// Asynchronous processing of one batch of already-built requests.
#[async_trait]
pub trait ProcessBatchRequests {
    /// Error reported when processing a batch fails.
    type Error;

    /// Processes the requests in `batch_requests`.
    ///
    /// `expected_content_type` is forwarded by reference to the implementor.
    /// Returns `Ok(())` on success or `Self::Error` on failure.
    async fn process_batch_requests(
        &self,
        batch_requests: &[LanguageModelBatchAPIRequest],
        expected_content_type: &ExpectedContentType,
    ) -> Result<(), Self::Error>;
}
/// End-to-end batch workflow assembled from the three supertraits.
///
/// `E` is the one error type shared by [`FinishProcessingUncompletedBatches`]
/// and [`ProcessBatchRequests`]. It must be constructible from
/// `LanguageModelBatchCreationError` — presumably so a failure from
/// `construct_batches` can be converted by `?` (confirm against that helper).
#[async_trait]
pub trait LanguageModelBatchWorkflow<E>:
    FinishProcessingUncompletedBatches<Error = E>
    + ComputeLanguageModelRequests
    + ProcessBatchRequests<Error = E>
where
    E: From<LanguageModelBatchCreationError>,
{
    /// Value passed to `construct_batches` as its second argument.
    /// Implementors may override the default of 80.
    const REQUESTS_PER_BATCH: usize = 80;

    /// Required method taking the same seed slice as the workflow.
    ///
    /// It is not called by the default
    /// `execute_language_model_batch_workflow` below; its semantics are left
    /// to the implementor.
    async fn plant_seed_and_wait(
        &mut self,
        input_tokens: &[<Self as ComputeLanguageModelRequests>::Seed],
    ) -> Result<(), E>;

    /// Runs the full workflow in order:
    ///
    /// 1. finish any uncompleted batches,
    /// 2. compute requests for `input_tokens` against `model`,
    /// 3. split them with `construct_batches`,
    /// 4. process each batch sequentially.
    ///
    /// The first error from any step aborts the workflow and is returned.
    async fn execute_language_model_batch_workflow(
        &mut self,
        model: LanguageModelType,
        expected_content_type: ExpectedContentType,
        input_tokens: &[<Self as ComputeLanguageModelRequests>::Seed],
    ) -> Result<(), E> {
        info!("Beginning full batch workflow execution");

        // Outstanding work is finished before any new requests are computed.
        self.finish_processing_uncompleted_batches(&expected_content_type)
            .await?;

        let pending = self.compute_language_model_requests(&model, input_tokens);

        // NOTE(review): the meaning of the trailing `false` flag is defined by
        // `construct_batches`, which is not visible from this file.
        let batches = construct_batches(&pending, Self::REQUESTS_PER_BATCH, false)?;

        // Batches are awaited one at a time, so a failure stops later batches.
        for (index, chunk) in batches {
            info!("Processing batch #{}", index);
            self.process_batch_requests(chunk, &expected_content_type)
                .await?;
        }

        Ok(())
    }
}
/// Asynchronously collects an output for each of a set of seeds.
#[async_trait]
pub trait LanguageModelBatchWorkflowGatherResults {
    /// Error reported when gathering fails.
    type Error;

    /// Seed type. It must implement `HasAssociatedOutputName`, `Clone` and
    /// `Named`.
    type Seed: HasAssociatedOutputName + Clone + Named;

    /// Output paired with each seed. It must implement `LoadFromFile` with
    /// `SaveLoadError` as that trait's error type.
    type Output: LoadFromFile<Error = SaveLoadError>;

    /// Returns `(seed, output)` pairs for `seeds`.
    ///
    /// NOTE(review): whether every input seed yields exactly one pair, and in
    /// what order, is not specified by this trait — confirm with implementors.
    async fn gather_results(
        &self,
        seeds: &[Self::Seed],
    ) -> Result<Vec<(Self::Seed, Self::Output)>, Self::Error>;
}