crate::ix!();
pub fn parse_derive_input_for_lmbw(ast: &DeriveInput) -> Result<LmbwParsedInput, Error> {
tracing::trace!("parse_derive_input_for_lmbw: start.");
let struct_ident = ast.ident.clone();
let generics = ast.generics.clone();
let mut custom_error_type: Option<Type> = None;
let mut json_output_format_type: Option<Type> = None;
for attr in &ast.attrs {
if attr.path().is_ident("batch_error_type") {
let parsed_ty = attr.parse_args::<Type>()?;
custom_error_type = Some(parsed_ty);
} else if attr.path().is_ident("batch_json_output_format") {
if json_output_format_type.is_some() {
return Err(Error::new_spanned(
attr,
"Duplicate #[batch_json_output_format(...)] attribute found.",
));
}
let parsed_ty = attr.parse_args::<Type>()?;
json_output_format_type = Some(parsed_ty);
}
}
let fields = match &ast.data {
syn::Data::Struct(ds) => match &ds.fields {
syn::Fields::Named(named) => &named.named,
_ => {
return Err(Error::new_spanned(
&ast.ident,
"LanguageModelBatchWorkflow derive only supports a named struct.",
));
}
},
_ => {
return Err(Error::new_spanned(
&ast.ident,
"LanguageModelBatchWorkflow derive can only be used on a struct.",
));
}
};
let mut batch_client_field: Option<syn::Ident> = None;
let mut batch_workspace_field: Option<syn::Ident> = None;
let mut model_type_field: Option<syn::Ident> = None;
let mut process_batch_output_fn_field: Option<syn::Ident> = None;
let mut process_batch_error_fn_field: Option<syn::Ident> = None;
for field in fields {
let field_ident = match &field.ident {
Some(id) => id.clone(),
None => continue,
};
for attr in &field.attrs {
if attr.path().is_ident("batch_client") {
batch_client_field = Some(field_ident.clone());
} else if attr.path().is_ident("batch_workspace") {
batch_workspace_field = Some(field_ident.clone());
} else if attr.path().is_ident("model_type") {
model_type_field = Some(field_ident.clone());
} else if attr.path().is_ident("custom_process_batch_output_fn") {
process_batch_output_fn_field = Some(field_ident.clone());
} else if attr.path().is_ident("custom_process_batch_error_fn") {
process_batch_error_fn_field = Some(field_ident.clone());
}
}
}
if batch_client_field.is_none() {
return Err(Error::new_spanned(
&ast.ident,
"Missing required `#[batch_client]` field.",
));
}
if batch_workspace_field.is_none() {
return Err(Error::new_spanned(
&ast.ident,
"Missing required `#[batch_workspace]` field.",
));
}
if model_type_field.is_none() {
return Err(Error::new_spanned(
&ast.ident,
"Missing required `#[model_type]` field.",
));
}
if custom_error_type.is_none() {
return Err(Error::new_spanned(
&ast.ident,
"Missing required `#[batch_error_type(...)]` attribute on the struct.",
));
}
let built = LmbwParsedInputBuilder::default()
.struct_ident(struct_ident)
.generics(generics)
.batch_client_field(batch_client_field)
.batch_workspace_field(batch_workspace_field)
.custom_error_type(custom_error_type)
.json_output_format_type(json_output_format_type)
.model_type_field(model_type_field)
.process_batch_output_fn_field(process_batch_output_fn_field)
.process_batch_error_fn_field(process_batch_error_fn_field)
.build()
.map_err(|e| {
Error::new_spanned(&ast.ident, format!("Builder error: {e}"))
})?;
Ok(built)
}
#[cfg(test)]
mod test_parse_derive_input_for_lmbw {
    use super::*;

    /// Runs the parser on `ast` and returns its error, panicking if parsing succeeds.
    ///
    /// A plain `match` is used instead of `expect_err`, so `LmbwParsedInput` is not
    /// required to implement `Debug`.
    fn expect_parse_error(ast: &DeriveInput) -> Error {
        match parse_derive_input_for_lmbw(ast) {
            Ok(_) => panic!("Expected parse to fail, but it succeeded."),
            Err(e) => e,
        }
    }

    #[traced_test]
    fn verifies_named_struct_parsing() {
        info!("Starting verifies_named_struct_parsing test for parse_derive_input_for_lmbw.");
        let ast: DeriveInput = parse_quote! {
            #[batch_error_type(MyCustomError)]
            struct Dummy {
                #[batch_client]
                some_client: std::sync::Arc<OpenAIClientHandle>,
                #[batch_workspace]
                some_workspace: std::sync::Arc<BatchWorkspace>,
                #[custom_process_batch_output_fn]
                pbo: BatchWorkflowProcessOutputFileFn,
                #[custom_process_batch_error_fn]
                pbe: BatchWorkflowProcessErrorFileFn,
                #[model_type]
                mt: LanguageModelType,
            }
        };
        let parsed = match parse_derive_input_for_lmbw(&ast) {
            Ok(x) => x,
            Err(e) => {
                panic!("Expected parse to succeed, but got error: {}", e);
            }
        };
        // Check that each marker captured the ident of the field it decorates,
        // not merely that some field was captured.
        assert_eq!(
            parsed.batch_client_field().as_ref().map(|i| i.to_string()),
            Some("some_client".to_string()),
            "Should have found batch_client field."
        );
        assert_eq!(
            parsed.batch_workspace_field().as_ref().map(|i| i.to_string()),
            Some("some_workspace".to_string()),
            "Should have found batch_workspace field."
        );
        assert_eq!(
            parsed.process_batch_output_fn_field().as_ref().map(|i| i.to_string()),
            Some("pbo".to_string()),
            "Should have found custom_process_batch_output_fn field."
        );
        assert_eq!(
            parsed.process_batch_error_fn_field().as_ref().map(|i| i.to_string()),
            Some("pbe".to_string()),
            "Should have found custom_process_batch_error_fn field."
        );
        assert_eq!(
            parsed.model_type_field().as_ref().map(|i| i.to_string()),
            Some("mt".to_string()),
            "Should have found model_type field."
        );
    }

    #[traced_test]
    fn optional_fields_default_to_none() {
        info!("Starting optional_fields_default_to_none test for parse_derive_input_for_lmbw.");
        let ast: DeriveInput = parse_quote! {
            #[batch_error_type(MyCustomError)]
            struct OnlyRequired {
                #[batch_client]
                client: std::sync::Arc<OpenAIClientHandle>,
                #[batch_workspace]
                workspace: std::sync::Arc<BatchWorkspace>,
                #[model_type]
                mt: LanguageModelType,
            }
        };
        let parsed = match parse_derive_input_for_lmbw(&ast) {
            Ok(x) => x,
            Err(e) => {
                panic!("Expected parse to succeed, but got error: {}", e);
            }
        };
        assert!(
            parsed.process_batch_output_fn_field().is_none(),
            "custom_process_batch_output_fn should be absent."
        );
        assert!(
            parsed.process_batch_error_fn_field().is_none(),
            "custom_process_batch_error_fn should be absent."
        );
    }

    #[traced_test]
    fn handles_struct_lacking_attributes() {
        info!("Starting handles_struct_lacking_attributes test for parse_derive_input_for_lmbw.");
        let ast: DeriveInput = parse_quote! {
            struct NoAttrs {
                field_a: i32,
                field_b: String,
            }
        };
        let e = expect_parse_error(&ast);
        info!("Got expected error: {}", e);
        // `#[batch_client]` is the first required attribute that is checked.
        assert!(
            e.to_string().contains("batch_client"),
            "Should report the missing batch_client field, got: {}",
            e
        );
    }

    #[traced_test]
    fn rejects_missing_batch_error_type() {
        info!("Starting rejects_missing_batch_error_type test for parse_derive_input_for_lmbw.");
        let ast: DeriveInput = parse_quote! {
            struct NoErrorType {
                #[batch_client]
                client: std::sync::Arc<OpenAIClientHandle>,
                #[batch_workspace]
                workspace: std::sync::Arc<BatchWorkspace>,
                #[model_type]
                mt: LanguageModelType,
            }
        };
        let e = expect_parse_error(&ast);
        assert!(
            e.to_string().contains("batch_error_type"),
            "Should report the missing batch_error_type attribute, got: {}",
            e
        );
    }

    #[traced_test]
    fn rejects_enum_input() {
        info!("Starting rejects_enum_input test for parse_derive_input_for_lmbw.");
        let ast: DeriveInput = parse_quote! {
            #[batch_error_type(MyCustomError)]
            enum NotAStruct {
                A,
                B,
            }
        };
        let e = expect_parse_error(&ast);
        assert!(
            e.to_string().contains("can only be used on a struct"),
            "Unexpected error for enum input: {}",
            e
        );
    }

    #[traced_test]
    fn rejects_tuple_struct() {
        info!("Starting rejects_tuple_struct test for parse_derive_input_for_lmbw.");
        let ast: DeriveInput = parse_quote! {
            #[batch_error_type(MyCustomError)]
            struct Tuple(i32, String);
        };
        let e = expect_parse_error(&ast);
        assert!(
            e.to_string().contains("only supports a named struct"),
            "Unexpected error for tuple struct: {}",
            e
        );
    }

    #[traced_test]
    fn rejects_duplicate_json_output_format() {
        info!("Starting rejects_duplicate_json_output_format test for parse_derive_input_for_lmbw.");
        let ast: DeriveInput = parse_quote! {
            #[batch_json_output_format(FormatA)]
            #[batch_json_output_format(FormatB)]
            struct Dup {
                a: i32,
            }
        };
        let e = expect_parse_error(&ast);
        assert_eq!(
            e.to_string(),
            "Duplicate #[batch_json_output_format(...)] attribute found."
        );
    }
}