use crate::
{
  client ::Client,
  error ::Error,
  models ::
  {
    GenerateContentRequest,
    GenerateContentResponse,
    ContentEmbedding,
    Content,
    Part,
    batch ::*,
  },
};
use std::time::{ Duration, Instant, SystemTime };
#[ derive( Debug ) ]
pub struct BatchApi< 'a >
{
client : &'a Client,
}
impl< 'a > BatchApi< 'a >
{
#[ inline ]
pub fn new( client : &'a Client ) -> Self
{
Self { client }
}
pub async fn create_inline(
&self,
model : &str,
requests : Vec< GenerateContentRequest >
) -> Result< BatchJob, Error >
{
let _base_url = &self.client.base_url;
let job_id = format!( "batch_job_{}", uuid::Uuid::new_v4() );
let request_count = requests.len();
let batch_job = BatchJob
{
job_id : job_id.clone(),
state : BatchJobState::Pending,
model : model.to_string(),
request_count,
create_time : Some( SystemTime::now() ),
expiration_time : Some( SystemTime::now() + Duration::from_secs( 86400 ) ), error : None,
};
Ok( batch_job )
}
pub async fn get_status( &self, job_id : &str ) -> Result< BatchJobStatus, Error >
{
let status = BatchJobStatus
{
job_id : job_id.to_string(),
state : BatchJobState::Running,
completed_count : 0,
failed_count : 0,
update_time : Some( SystemTime::now() ),
error : None,
};
Ok( status )
}
pub async fn wait_and_retrieve(
&self,
job_id : &str,
timeout : Duration
) -> Result< BatchJobResults, Error >
{
let start = SystemTime::now();
let poll_interval = Duration::from_secs( 5 );
loop
{
let status = self.get_status( job_id ).await?;
match status.state
{
BatchJobState::Succeeded | BatchJobState::PartiallyCompleted =>
{
return self.retrieve_results( job_id ).await;
}
BatchJobState::Failed =>
{
return Err( Error::ApiError(
status.error.unwrap_or_else( || "Batch job failed".to_string() )
) );
}
BatchJobState::Cancelled =>
{
return Err( Error::ApiError( "Batch job was cancelled".to_string() ) );
}
BatchJobState::Pending | BatchJobState::Running =>
{
if start.elapsed().unwrap_or( Duration::ZERO ) > timeout
{
return Err( Error::ApiError( "Batch job timeout".to_string() ) );
}
tokio ::time::sleep( poll_interval ).await;
}
}
}
}
async fn retrieve_results( &self, job_id : &str ) -> Result< BatchJobResults, Error >
{
let results = BatchJobResults
{
job_id : job_id.to_string(),
state : BatchJobState::Succeeded,
responses : vec!
[
GenerateContentResponse
{
candidates : vec!
[
crate ::models::Candidate
{
content : Content
{
parts : vec!
[
Part
{
text : Some( "Mock response".to_string() ),
..Default::default()
}
],
role : "model".to_string(),
},
finish_reason : Some( "STOP".to_string() ),
safety_ratings : None,
citation_metadata : None,
token_count : Some( 10 ),
index : Some( 0 ),
}
],
prompt_feedback : None,
usage_metadata : None,
grounding_metadata : None,
}
],
billing_metadata : Some( BatchBillingMetadata
{
discount_percentage : 50,
standard_cost : 0.02,
discounted_cost : 0.01,
total_tokens : 100,
} ),
retrieve_time : Some( SystemTime::now() ),
};
Ok( results )
}
pub async fn cancel( &self, job_id : &str ) -> Result< (), Error >
{
let _ = job_id;
Ok( () )
}
pub async fn list( &self ) -> Result< BatchJobList, Error >
{
self.list_with_page_size( None, None ).await
}
pub async fn list_with_token( &self, page_token : &str ) -> Result< BatchJobList, Error >
{
self.list_with_page_size( None, Some( page_token.to_string() ) ).await
}
async fn list_with_page_size(
&self,
_page_size : Option< i32 >,
_page_token : Option< String >
) -> Result< BatchJobList, Error >
{
let list = BatchJobList
{
jobs : vec![],
next_page_token : None,
};
Ok( list )
}
pub async fn create_embedding_batch(
&self,
model : &str,
texts : Vec< String >
) -> Result< BatchJob, Error >
{
let job_id = format!( "batch_embed_{}", uuid::Uuid::new_v4() );
let request_count = texts.len();
let batch_job = BatchJob
{
job_id : job_id.clone(),
state : BatchJobState::Pending,
model : model.to_string(),
request_count,
create_time : Some( SystemTime::now() ),
expiration_time : Some( SystemTime::now() + Duration::from_secs( 86400 ) ),
error : None,
};
Ok( batch_job )
}
pub async fn wait_and_retrieve_embeddings(
&self,
job_id : &str,
timeout : Duration
) -> Result< BatchEmbeddingResults, Error >
{
let start = SystemTime::now();
let poll_interval = Duration::from_secs( 5 );
loop
{
let status = self.get_status( job_id ).await?;
match status.state
{
BatchJobState::Succeeded | BatchJobState::PartiallyCompleted =>
{
return self.retrieve_embedding_results( job_id ).await;
}
BatchJobState::Failed =>
{
return Err( Error::ApiError(
status.error.unwrap_or_else( || "Batch job failed".to_string() )
) );
}
BatchJobState::Cancelled =>
{
return Err( Error::ApiError( "Batch job was cancelled".to_string() ) );
}
BatchJobState::Pending | BatchJobState::Running =>
{
if start.elapsed().unwrap_or( Duration::ZERO ) > timeout
{
return Err( Error::ApiError( "Batch job timeout".to_string() ) );
}
tokio ::time::sleep( poll_interval ).await;
}
}
}
}
async fn retrieve_embedding_results( &self, job_id : &str ) -> Result< BatchEmbeddingResults, Error >
{
let results = BatchEmbeddingResults
{
job_id : job_id.to_string(),
state : BatchJobState::Succeeded,
embeddings : vec!
[
ContentEmbedding
{
values : vec![ 0.1, 0.2, 0.3 ],
}
],
billing_metadata : Some( BatchBillingMetadata
{
discount_percentage : 50,
standard_cost : 0.01,
discounted_cost : 0.005,
total_tokens : 50,
} ),
};
Ok( results )
}
}