use crate::ElementList;
use reqwest::multipart::Form;
use serde::{Deserialize, Serialize};
/// Chunking strategy applied after partitioning; sent as the `chunking_strategy` form field.
///
/// Variants (de)serialize as snake_case strings, e.g. `ByPage` <-> `"by_page"`.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum ChunkingStrategy {
    /// Serialized as `"basic"`.
    Basic,
    /// Serialized as `"by_page"`.
    ByPage,
    /// Serialized as `"by_similarity"`.
    BySimilarity,
    /// Serialized as `"by_title"`.
    ByTitle,
}
/// Partitioning strategy; sent as the `strategy` form field.
///
/// Variants (de)serialize as snake_case strings, e.g. `HiRes` <-> `"hi_res"`,
/// `OcrOnly` <-> `"ocr_only"`.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum Strategy {
    /// Serialized as `"fast"`.
    Fast,
    /// Serialized as `"hi_res"`.
    HiRes,
    /// Serialized as `"auto"`.
    Auto,
    /// Serialized as `"ocr_only"`.
    OcrOnly,
}
/// Requested response format, (de)serialized as its MIME type string.
///
/// NOTE(review): `PartitionParameters::output_format` is a plain `String`, not this enum.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputFormat {
    /// Serialized as `"application/json"`.
    #[serde(rename = "application/json")]
    ApplicationJson,
    /// Serialized as `"text/csv"`.
    #[serde(rename = "text/csv")]
    TextCsv,
}
/// Parameters for a partition request.
///
/// Converted into a multipart `Form` by `From<PartitionParameters> for Form`: booleans and
/// integers are sent via `to_string()`, `Vec<String>` fields as JSON arrays, enums as their
/// bare serde names, and `Option` fields only when they are `Some`.
/// NOTE(review): field declaration order is also the serde key order of `Serialize`.
#[derive(Debug, Serialize, Deserialize)]
pub struct PartitionParameters {
/// Sent as the `coordinates` form field (`"true"`/`"false"`).
pub coordinates: bool,
/// Sent as `encoding` only when `Some`; `Default` uses `"utf-8"`.
pub encoding: Option<String>,
/// Sent as a JSON array string under `extract_image_block_types`.
pub extract_image_block_types: Vec<String>,
/// Sent as `gz_uncompressed_content_type` only when `Some`.
pub gz_uncompressed_content_type: Option<String>,
/// Sent as `hi_res_model_name` only when `Some`.
pub hi_res_model_name: Option<String>,
/// Sent as the `include_page_breaks` form field.
pub include_page_breaks: bool,
/// Sent as a JSON array string under `languages` only when `Some`.
pub languages: Option<Vec<String>>,
/// MIME type string, e.g. `"application/json"` (the `Default`); sent verbatim.
pub output_format: String,
/// Sent as a JSON array string under `skip_infer_table_types`.
pub skip_infer_table_types: Vec<String>,
/// Sent as `starting_page_number` only when `Some`.
pub starting_page_number: Option<i32>,
/// Sent as its snake_case name (e.g. `"hi_res"`); `Default` is `Strategy::Auto`.
pub strategy: Strategy,
/// Sent as the `unique_element_ids` form field.
pub unique_element_ids: bool,
/// Sent as the `xml_keep_tags` form field.
pub xml_keep_tags: bool,
/// Sent as its snake_case name only when `Some`; presumably `None` means no chunking — confirm.
pub chunking_strategy: Option<ChunkingStrategy>,
/// Chunking option (presumably only meaningful with a chunking strategy); sent only when `Some`.
pub combine_under_n_chars: Option<i32>,
/// Sent as the `include_orig_elements` form field; `Default` is `true`.
pub include_orig_elements: bool,
/// Chunking option; sent only when `Some`.
pub max_characters: Option<i32>,
/// Sent as the `multipage_sections` form field; `Default` is `true`.
pub multipage_sections: bool,
/// Chunking option; sent only when `Some`.
pub new_after_n_chars: Option<i32>,
/// Sent as the `overlap` form field; `Default` is `0`.
pub overlap: i32,
/// Sent as the `overlap_all` form field.
pub overlap_all: bool,
/// Optional threshold; presumably used with `ChunkingStrategy::BySimilarity` — confirm.
pub similarity_threshold: Option<f64>,
}
impl Default for PartitionParameters {
    /// Builds the baseline request parameters.
    ///
    /// JSON output, UTF-8 encoding and the `Auto` strategy; `include_orig_elements` and
    /// `multipage_sections` are on, every other flag is off, `overlap` is zero, list
    /// parameters are empty and all remaining optional parameters are unset.
    fn default() -> Self {
        Self {
            // Format / strategy selection.
            output_format: String::from("application/json"),
            encoding: Some(String::from("utf-8")),
            strategy: Strategy::Auto,
            // Flags that start enabled.
            include_orig_elements: true,
            multipage_sections: true,
            // Flags that start disabled.
            coordinates: false,
            include_page_breaks: false,
            unique_element_ids: false,
            xml_keep_tags: false,
            overlap_all: false,
            // List-valued parameters start empty.
            extract_image_block_types: Vec::new(),
            skip_infer_table_types: Vec::new(),
            overlap: 0,
            // Everything else is left unset.
            gz_uncompressed_content_type: None,
            hi_res_model_name: None,
            languages: None,
            starting_page_number: None,
            chunking_strategy: None,
            combine_under_n_chars: None,
            max_characters: None,
            new_after_n_chars: None,
            similarity_threshold: None,
        }
    }
}
impl From<PartitionParameters> for Form {
fn from(value: PartitionParameters) -> Self {
let mut form = Form::new();
form = form.text("coordinates", value.coordinates.to_string());
if let Some(encoding) = value.encoding.clone() {
form = form.text("encoding", encoding);
}
form = form.text(
"extract_image_block_types",
serde_json::to_string(&value.extract_image_block_types).unwrap(),
);
if let Some(gz_uncompressed_content_type) = value.gz_uncompressed_content_type.clone() {
form = form.text("gz_uncompressed_content_type", gz_uncompressed_content_type);
}
if let Some(hi_res_model_name) = value.hi_res_model_name.clone() {
form = form.text("hi_res_model_name", hi_res_model_name);
}
form = form.text("include_page_breaks", value.include_page_breaks.to_string());
if let Some(languages) = value.languages.clone() {
form = form.text("languages", serde_json::to_string(&languages).unwrap());
}
form = form.text("output_format", value.output_format.clone());
form = form.text(
"skip_infer_table_types",
serde_json::to_string(&value.skip_infer_table_types).unwrap(),
);
if let Some(starting_page_number) = value.starting_page_number {
form = form.text("starting_page_number", starting_page_number.to_string());
}
form = form.text("strategy", {
let s = String::from(
serde_json::to_string(&value.strategy)
.expect("Could not convert Strategy enum to string.")
.trim_matches('"'),
);
s
});
form = form.text("unique_element_ids", value.unique_element_ids.to_string());
form = form.text("xml_keep_tags", value.xml_keep_tags.to_string());
if let Some(chunking_strategy) = value
.chunking_strategy
.as_ref()
.map(serde_json::to_string)
.transpose()
.expect("Could not convert Chunking Strategy enum to string.")
{
form = form.text(
"chunking_strategy",
chunking_strategy.trim_matches('"').to_string(),
);
}
if let Some(combine_under_n_chars) = value.combine_under_n_chars {
form = form.text("combine_under_n_chars", combine_under_n_chars.to_string());
}
form = form.text(
"include_orig_elements",
value.include_orig_elements.to_string(),
);
if let Some(max_characters) = value.max_characters {
form = form.text("max_characters", max_characters.to_string());
}
form = form.text("multipage_sections", value.multipage_sections.to_string());
if let Some(new_after_n_chars) = value.new_after_n_chars {
form = form.text("new_after_n_chars", new_after_n_chars.to_string());
}
form = form.text("overlap", value.overlap.to_string());
form = form.text("overlap_all", value.overlap_all.to_string());
form
}
}
/// One segment of a [`ValidationError`]'s `loc` path.
///
/// With `#[serde(untagged)]`, a JSON string deserializes as `Str` and a JSON integer
/// (fitting in `i64`) as `Int`; both serialize back as the bare value.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum LocElement {
/// A string segment (presumably a field/section name — confirm against the API).
Str(String),
/// An integer segment (presumably a list index — confirm against the API).
Int(i64),
}
/// A single validation error object with `loc`, `msg` and `type` keys.
#[derive(Serialize, Deserialize, Debug)]
pub struct ValidationError {
/// Path to the offending input, as a mix of string and integer segments.
pub loc: Vec<LocElement>,
/// Error message text.
pub msg: String,
/// Error type identifier. `r#type` is used because `type` is a Rust keyword;
/// serde maps it to the `type` key.
pub r#type: String,
}
/// Body of a partition response, decoded without a tag.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so the order is
/// load-bearing: `UnknownFailure` accepts any JSON value and must remain last.
/// NOTE(review): `ValidationFailure` matches only a single top-level object with
/// `loc`/`msg`/`type`; if the server wraps errors (e.g. `{"detail": [...]}`), such bodies
/// will land in `UnknownFailure` — confirm against the API.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum PartitionResponse {
/// Body that deserializes as an `ElementList`.
Success(ElementList),
/// Body that deserializes as a single [`ValidationError`].
ValidationFailure(ValidationError),
/// Any other JSON body, kept as a raw value.
UnknownFailure(serde_json::Value),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults must match what `PartitionParameters::default()` promises.
    #[test]
    fn test_default_partition_params() {
        let params = PartitionParameters::default();
        assert!(!params.coordinates);
        assert_eq!(params.encoding.as_deref(), Some("utf-8"));
        assert!(params.extract_image_block_types.is_empty());
        assert!(params.skip_infer_table_types.is_empty());
        assert_eq!(params.output_format, "application/json");
        assert_eq!(params.strategy, Strategy::Auto);
        assert_eq!(params.chunking_strategy, None);
        assert!(params.include_orig_elements);
        assert!(params.multipage_sections);
        assert_eq!(params.overlap, 0);
        assert!(!params.overlap_all);
        assert_eq!(params.similarity_threshold, None);
    }

    #[test]
    fn test_deserialize_chunking_strategy() {
        let json = r#""basic""#;
        let strategy: ChunkingStrategy = serde_json::from_str(json).unwrap();
        assert_eq!(strategy, ChunkingStrategy::Basic);
    }

    #[test]
    fn test_deserialize_strategy() {
        let json = r#""auto""#;
        let strategy: Strategy = serde_json::from_str(json).unwrap();
        assert_eq!(strategy, Strategy::Auto);
    }

    #[test]
    fn test_deserialize_output_format() {
        let json = r#""application/json""#;
        let format: OutputFormat = serde_json::from_str(json).unwrap();
        assert_eq!(format, OutputFormat::ApplicationJson);
    }

    /// Multi-word variants must use the snake_case / MIME names the form encoder relies on.
    #[test]
    fn test_serialize_multi_word_variants() {
        assert_eq!(serde_json::to_string(&Strategy::HiRes).unwrap(), r#""hi_res""#);
        assert_eq!(serde_json::to_string(&Strategy::OcrOnly).unwrap(), r#""ocr_only""#);
        assert_eq!(
            serde_json::to_string(&ChunkingStrategy::BySimilarity).unwrap(),
            r#""by_similarity""#
        );
        assert_eq!(
            serde_json::to_string(&ChunkingStrategy::ByPage).unwrap(),
            r#""by_page""#
        );
        assert_eq!(
            serde_json::to_string(&OutputFormat::TextCsv).unwrap(),
            r#""text/csv""#
        );
    }

    #[test]
    fn test_deserialize_partition_parameters() {
        let json = r#"{
            "coordinates": true,
            "encoding": "utf-8",
            "extract_image_block_types": [],
            "gz_uncompressed_content_type": null,
            "hi_res_model_name": null,
            "include_page_breaks": true,
            "languages": null,
            "output_format": "application/json",
            "skip_infer_table_types": [],
            "starting_page_number": null,
            "strategy": "auto",
            "unique_element_ids": false,
            "xml_keep_tags": false,
            "chunking_strategy": null,
            "combine_under_n_chars": null,
            "include_orig_elements": true,
            "max_characters": null,
            "multipage_sections": true,
            "new_after_n_chars": null,
            "overlap": 0,
            "overlap_all": false,
            "similarity_threshold": null
        }"#;
        let params: PartitionParameters = serde_json::from_str(json).unwrap();
        assert!(params.coordinates);
        assert_eq!(params.encoding.as_deref(), Some("utf-8"));
        assert!(params.include_page_breaks);
        assert_eq!(params.output_format, "application/json");
        assert_eq!(params.strategy, Strategy::Auto);
        assert_eq!(params.chunking_strategy, None);
        assert!(params.include_orig_elements);
        assert!(params.multipage_sections);
        assert_eq!(params.overlap, 0);
        assert!(!params.overlap_all);
    }

    /// `loc` mixes string and integer segments; both must decode into the right variant.
    #[test]
    fn test_deserialize_validation_error() {
        let json = r#"{"loc": ["body", 0], "msg": "bad", "type": "value_error"}"#;
        let err: ValidationError = serde_json::from_str(json).unwrap();
        assert!(matches!(
            err.loc.as_slice(),
            [LocElement::Str(s), LocElement::Int(0)] if s == "body"
        ));
        assert_eq!(err.msg, "bad");
        assert_eq!(err.r#type, "value_error");
    }

    /// Every branch of the form encoder (optional fields, lists, enums) must run without panicking.
    #[test]
    fn test_partition_parameters_into_form() {
        let params = PartitionParameters {
            languages: Some(vec!["eng".to_string()]),
            starting_page_number: Some(3),
            chunking_strategy: Some(ChunkingStrategy::BySimilarity),
            max_characters: Some(500),
            similarity_threshold: Some(0.5),
            ..Default::default()
        };
        let _form: Form = params.into();
    }
}