use crate::bareun::stream_correct_error_response::Res as StreamRes;
use crate::bareun::{
CancelledRevision, CorrectErrorRequest, CorrectErrorResponse, Document, EncodingType,
PostRevision, ProgressRevision, RevisionConfig, StreamCorrectErrorRequest,
StreamCorrectErrorResponse, StreamFirstCorrectError,
};
use crate::error::Result;
use crate::revision_service_client::BareunRevisionServiceClient;
use tonic::Streaming;
/// Fluent builder for [`RevisionConfig`].
///
/// Starts from `RevisionConfig::default()`; each setter consumes the builder, overwrites a
/// single flag of the wrapped config and returns the builder, so calls can be chained before
/// [`RevisionConfigBuilder::build`] hands back the finished config.
#[derive(Debug, Default, Clone)]
pub struct RevisionConfigBuilder {
    config: RevisionConfig,
}

impl RevisionConfigBuilder {
    /// Creates a builder whose config equals `RevisionConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `disable_split_sentence` flag of the config to `v`.
    #[must_use]
    pub fn disable_split_sentence(mut self, v: bool) -> Self {
        self.config.disable_split_sentence = v;
        self
    }

    /// Sets the `disable_caret_spacing` flag of the config to `v`.
    #[must_use]
    pub fn disable_caret_spacing(mut self, v: bool) -> Self {
        self.config.disable_caret_spacing = v;
        self
    }

    /// Sets the `disable_vx_spacing` flag of the config to `v`.
    #[must_use]
    pub fn disable_vx_spacing(mut self, v: bool) -> Self {
        self.config.disable_vx_spacing = v;
        self
    }

    /// Sets the `treat_as_title` flag of the config to `v`.
    #[must_use]
    pub fn treat_as_title(mut self, v: bool) -> Self {
        self.config.treat_as_title = v;
        self
    }

    /// Sets the `enable_limited_punctuation` flag of the config to `v`.
    #[must_use]
    pub fn enable_limited_punctuation(mut self, v: bool) -> Self {
        self.config.enable_limited_punctuation = v;
        self
    }

    /// Sets the `disable_confusion` flag of the config to `v`.
    #[must_use]
    pub fn disable_confusion(mut self, v: bool) -> Self {
        self.config.disable_confusion = v;
        self
    }

    /// Sets the `enable_cleanup_whitespace` flag of the config to `v`.
    #[must_use]
    pub fn enable_cleanup_whitespace(mut self, v: bool) -> Self {
        self.config.enable_cleanup_whitespace = v;
        self
    }

    /// Sets the `disable_typo_correction` flag of the config to `v`.
    #[must_use]
    pub fn disable_typo_correction(mut self, v: bool) -> Self {
        self.config.disable_typo_correction = v;
        self
    }

    /// Sets the `enable_sentence_check` flag of the config to `v`.
    #[must_use]
    pub fn enable_sentence_check(mut self, v: bool) -> Self {
        self.config.enable_sentence_check = v;
        self
    }

    /// Consumes the builder and returns the accumulated [`RevisionConfig`].
    #[must_use]
    pub fn build(self) -> RevisionConfig {
        self.config
    }
}
/// Typed view of a single message received from the streaming correction RPC.
///
/// Each variant wraps the payload of the identically-named case of the `res` oneof carried by
/// a `StreamCorrectErrorResponse`; build one with `StreamRevisionEvent::from_message`.
#[derive(Debug, Clone, PartialEq)]
pub enum StreamRevisionEvent {
    /// Payload of the `First` case of the response oneof.
    First(StreamFirstCorrectError),
    /// Payload of the `Cancelled` case of the response oneof.
    Cancelled(CancelledRevision),
    /// Payload of the `Post` case of the response oneof.
    Post(PostRevision),
    /// Payload of the `Progress` case of the response oneof.
    Progress(ProgressRevision),
}
impl StreamRevisionEvent {
    /// Converts one raw stream message into a typed event.
    ///
    /// Returns `None` when the message's `res` oneof is unset; otherwise each oneof case is
    /// mapped onto the event variant of the same name, moving its payload across unchanged.
    pub fn from_message(msg: StreamCorrectErrorResponse) -> Option<Self> {
        let payload = msg.res?;
        let event = match payload {
            StreamRes::First(first) => Self::First(first),
            StreamRes::Cancelled(cancelled) => Self::Cancelled(cancelled),
            StreamRes::Post(post) => Self::Post(post),
            StreamRes::Progress(progress) => Self::Progress(progress),
        };
        Some(event)
    }
}
pub struct Corrector {
pub client: BareunRevisionServiceClient,
}
impl Corrector {
pub async fn new(apikey: &str, host: &str, port: Option<u16>) -> Result<Self> {
let client = BareunRevisionServiceClient::new(apikey, host, port).await?;
Ok(Corrector { client })
}
pub async fn correct_error_with(
&mut self,
content: &str,
custom_dicts: &[String],
builder: RevisionConfigBuilder,
) -> Result<CorrectErrorResponse> {
self.correct_error(content, custom_dicts, Some(builder.build()))
.await
}
pub async fn stream_correct_error_builder(
&mut self,
content: &str,
custom_dicts: &[String],
builder: RevisionConfigBuilder,
req_id: i64,
) -> Result<Streaming<StreamCorrectErrorResponse>> {
self.stream_correct_error(content, custom_dicts, Some(builder.build()), req_id)
.await
}
pub async fn correct_error(
&mut self,
content: &str,
custom_dicts: &[String],
config: Option<RevisionConfig>,
) -> Result<CorrectErrorResponse> {
#[allow(deprecated)]
let request = CorrectErrorRequest {
document: Some(Document {
content: content.to_string(),
language: "ko_KR".to_string(),
}),
encoding_type: EncodingType::Utf32.into(),
custom_domain: String::new(), custom_dict_names: custom_dicts.to_vec(),
config,
};
self.client.correct_error(request).await
}
pub fn print_results(&self, res: &CorrectErrorResponse) {
println!("원문: {}", res.origin);
println!("교정: {}", res.revised);
println!("\n=== 교정된 문장들 ===");
for sent in &res.revised_sentences {
println!(" 원문: {}", sent.origin);
println!("교정문: {}", sent.revised);
}
for block in &res.revised_blocks {
if let Some(origin) = &block.origin {
println!(
"원문:{} offset:{}, length:{}",
origin.content, origin.begin_offset, origin.length
);
}
println!("대표 교정: {}", block.revised);
if let Some(tc) = block.thinking_count {
println!(" 생각 중인 교정 수: {}", tc);
}
for rev in &block.revisions {
let help_text = res
.helps
.get(&rev.help_id)
.map(|h| h.comment.as_str())
.unwrap_or("");
let thinking = rev
.thinking_id
.map(|id| format!(", thinking_id:{}", id))
.unwrap_or_default();
println!(
" 교정: {}, 카테고리:{}{}, 도움말 {}",
rev.revised, rev.category, thinking, help_text
);
}
}
for cleanup in &res.whitespace_cleanup_ranges {
println!(
"공백제거: offset:{} length:{} position: {}",
cleanup.offset, cleanup.length, cleanup.position
);
}
}
pub async fn stream_correct_error(
&mut self,
content: &str,
custom_dicts: &[String],
config: Option<RevisionConfig>,
req_id: i64,
) -> Result<Streaming<StreamCorrectErrorResponse>> {
#[allow(deprecated)]
let request = StreamCorrectErrorRequest {
document: Some(Document {
content: content.to_string(),
language: "ko_KR".to_string(),
}),
encoding_type: EncodingType::Utf32.into(),
custom_domain: String::new(),
custom_dict_names: custom_dicts.to_vec(),
config,
req_id,
};
self.client.stream_correct_error(request).await
}
pub async fn stream_correct_error_with<F>(
&mut self,
content: &str,
custom_dicts: &[String],
config: Option<RevisionConfig>,
req_id: i64,
mut on_event: F,
) -> Result<()>
where
F: FnMut(StreamRevisionEvent) -> bool,
{
let mut stream = self
.stream_correct_error(content, custom_dicts, config, req_id)
.await?;
while let Some(msg) = stream.message().await? {
if let Some(event) = StreamRevisionEvent::from_message(msg) {
if !on_event(event) {
break;
}
}
}
Ok(())
}
pub fn as_json_str(&self, response: &CorrectErrorResponse) -> Result<String> {
Ok(serde_json::to_string_pretty(response)?)
}
pub fn print_as_json(&self, response: &CorrectErrorResponse) -> Result<()> {
println!("{}", self.as_json_str(response)?);
Ok(())
}
}