use crate::core::providers::unified_provider::ProviderError;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Request body for submitting a batch inference job.
///
/// Serialized as camelCase JSON (e.g. `jobName`, `roleArn`) and sent by
/// `BatchClient::create_job` to the `model-invocation-job` endpoint.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateBatchJobRequest {
    /// Caller-chosen name for the job.
    pub job_name: String,
    /// Model identifier the job runs against.
    pub model_id: String,
    /// Where the job reads its input records from.
    pub input_data_config: InputDataConfig,
    /// Where the job writes its output.
    pub output_data_config: OutputDataConfig,
    /// IAM role ARN sent with the request.
    pub role_arn: String,
    /// Optional resource tags; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<Vec<Tag>>,
    /// Optional timeout in hours; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout_duration_in_hours: Option<u32>,
}
/// Input location wrapper for a batch job (`inputDataConfig` in JSON).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InputDataConfig {
    /// S3 source of the job's input (`s3InputDataConfig`).
    pub s3_input_data_config: S3InputDataConfig,
}
/// S3 location and format of a batch job's input (`s3InputDataConfig` in JSON).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct S3InputDataConfig {
    /// S3 URI of the input data (`s3Uri`).
    pub s3_uri: String,
    /// Input file format (`s3InputFormat`).
    ///
    /// The service treats this field as optional. `default` lets job
    /// descriptions that omit it still deserialize; it becomes an empty string.
    /// An empty value is left out when serializing rather than sent as `""`,
    /// which the service would not accept as a format.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub s3_input_format: String,
}
/// Output location wrapper for a batch job (`outputDataConfig` in JSON).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct OutputDataConfig {
    /// S3 destination for the job's output (`s3OutputDataConfig`).
    pub s3_output_data_config: S3OutputDataConfig,
}
/// S3 destination for a batch job's output (`s3OutputDataConfig` in JSON).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct S3OutputDataConfig {
    /// S3 URI the output is written under (`s3Uri`).
    pub s3_uri: String,
    /// Optional encryption key identifier (`s3EncryptionKeyId`); the key is
    /// omitted from serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub s3_encryption_key_id: Option<String>,
}
/// A key/value resource tag attached to a batch job request.
///
/// Field names are already lowercase, so they serialize as `key` / `value`
/// without any rename rule.
#[derive(Debug, Serialize, Deserialize)]
pub struct Tag {
    /// Tag key.
    pub key: String,
    /// Tag value.
    pub value: String,
}
/// Response returned after submitting a batch job.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchJobResponse {
    /// ARN of the newly created job (`jobArn`).
    pub job_arn: String,
}
/// Full description of a batch job as returned by `BatchClient::get_job`.
///
/// Deserialized from camelCase JSON. Missing `Option` fields become `None`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchJobDetails {
    /// Job ARN (`jobArn`).
    pub job_arn: String,
    /// Job name (`jobName`).
    pub job_name: String,
    /// Current lifecycle status (`status`).
    pub status: BatchJobStatus,
    /// Model identifier (`modelId`).
    pub model_id: String,
    /// Input location (`inputDataConfig`).
    pub input_data_config: InputDataConfig,
    /// Output location (`outputDataConfig`).
    pub output_data_config: OutputDataConfig,
    /// IAM role ARN (`roleArn`).
    pub role_arn: String,
    /// Submission timestamp.
    ///
    /// The GetModelInvocationJob API names this `submitTime`. `createdAt` is
    /// still accepted for compatibility.
    #[serde(alias = "submitTime")]
    pub created_at: String,
    /// Last-modified timestamp. The API names this `lastModifiedTime`;
    /// `lastModifiedAt` is still accepted.
    #[serde(alias = "lastModifiedTime")]
    pub last_modified_at: Option<String>,
    /// Completion timestamp (`endTime`), if any.
    pub end_time: Option<String>,
    /// Expiration timestamp (`jobExpirationTime`), if any.
    pub job_expiration_time: Option<String>,
    /// Configured timeout in hours (`timeoutDurationInHours`), if any.
    pub timeout_duration_in_hours: Option<u32>,
    /// Token statistics (`jobStatistics`), if the response includes them.
    pub job_statistics: Option<JobStatistics>,
}
/// Lifecycle status of a batch inference job.
///
/// Bedrock reports statuses in PascalCase (`"InProgress"`,
/// `"PartiallyCompleted"`, ...), which matches the variant names directly.
/// SCREAMING_SNAKE_CASE spellings, which the previous rename rule expected,
/// are accepted as aliases so either form deserializes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
pub enum BatchJobStatus {
    #[serde(alias = "SUBMITTED")]
    Submitted,
    #[serde(alias = "IN_PROGRESS")]
    InProgress,
    #[serde(alias = "COMPLETED")]
    Completed,
    #[serde(alias = "FAILED")]
    Failed,
    #[serde(alias = "STOPPING")]
    Stopping,
    #[serde(alias = "STOPPED")]
    Stopped,
    #[serde(alias = "PARTIALLY_COMPLETED")]
    PartiallyCompleted,
    #[serde(alias = "EXPIRED")]
    Expired,
    #[serde(alias = "VALIDATING")]
    Validating,
    #[serde(alias = "SCHEDULED")]
    Scheduled,
}
/// Optional token counters attached to a batch job description.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct JobStatistics {
    /// Number of input tokens (`inputTokenCount`), if reported.
    pub input_token_count: Option<u64>,
    /// Number of output tokens (`outputTokenCount`), if reported.
    pub output_token_count: Option<u64>,
}
/// Batch-job operations (`model-invocation-job` endpoints) layered on top of
/// a borrowed Bedrock client.
pub struct BatchClient<'a> {
    /// Underlying client used to send the HTTP requests.
    client: &'a crate::core::providers::bedrock::client::BedrockClient,
}
impl<'a> BatchClient<'a> {
    /// Creates a batch client that borrows `client` for its requests.
    pub fn new(client: &'a crate::core::providers::bedrock::client::BedrockClient) -> Self {
        Self { client }
    }

    /// Submits a new batch inference job.
    ///
    /// `request` is serialized to JSON and sent to the `model-invocation-job`
    /// endpoint. The response body is parsed into a [`BatchJobResponse`].
    ///
    /// # Errors
    /// Returns an error if `request` cannot be serialized to JSON, if the
    /// underlying request fails, or if the response body cannot be parsed.
    pub async fn create_job(
        &self,
        request: CreateBatchJobRequest,
    ) -> Result<BatchJobResponse, ProviderError> {
        let body = serde_json::to_value(request)?;
        let response = self
            .client
            .send_request("", "model-invocation-job", &body)
            .await?;
        let job_response: BatchJobResponse = response
            .json()
            .await
            .map_err(|e| ProviderError::response_parsing("bedrock", e.to_string()))?;
        Ok(job_response)
    }

    /// Fetches the current description of a batch job.
    ///
    /// `job_identifier` may be either the bare job ID or the full job ARN, such
    /// as the `job_arn` returned by [`Self::create_job`].
    ///
    /// # Errors
    /// Returns an error if the underlying request fails or the response body
    /// cannot be parsed into [`BatchJobDetails`].
    pub async fn get_job(&self, job_identifier: &str) -> Result<BatchJobDetails, ProviderError> {
        let url = format!("model-invocation-job/{}", Self::job_id(job_identifier));
        let response = self.client.send_get_request(&url).await?;
        let job_details: BatchJobDetails = response
            .json()
            .await
            .map_err(|e| ProviderError::response_parsing("bedrock", e.to_string()))?;
        Ok(job_details)
    }

    /// Requests that a batch job be stopped.
    ///
    /// `job_identifier` may be either the bare job ID or the full job ARN. The
    /// response body is not inspected.
    ///
    /// # Errors
    /// Returns an error if the underlying request fails.
    pub async fn stop_job(&self, job_identifier: &str) -> Result<(), ProviderError> {
        let url = format!("model-invocation-job/{}/stop", Self::job_id(job_identifier));
        self.client.send_request("", &url, &Value::Null).await?;
        Ok(())
    }

    /// Reduces a job identifier (bare ID or full ARN) to the bare job ID.
    ///
    /// Job ARNs end in `model-invocation-job/<id>`. Embedding the whole ARN in
    /// a URL path would add extra `/` segments (and unescaped `:`), so the
    /// route would no longer match. The API accepts the bare ID, so only the
    /// trailing segment is kept. Identifiers without a `/`, or with an empty
    /// trailing segment, are returned unchanged.
    fn job_id(job_identifier: &str) -> &str {
        match job_identifier.rsplit_once('/') {
            Some((_, id)) if !id.is_empty() => id,
            _ => job_identifier,
        }
    }
}