1use crate::{
4 error::Error,
5 proto::{Query as PluginQuery, QueryState},
6 types::Query,
7};
8use anyhow::{Result, anyhow};
9use std::result::Result as StdResult;
10
/// The gRPC message size limit this module chunks against, in bytes (4 MiB).
pub const GRPC_MAX_SIZE_BYTES: usize = 4 * 1024 * 1024;

/// Per-chunk byte budget used by `chunk`: the gRPC limit minus 1 KiB of
/// headroom (presumably for non-chunked fields and protobuf framing — TODO
/// confirm).
const GRPC_EFFECTIVE_MAX_SIZE: usize = GRPC_MAX_SIZE_BYTES - 1024;

/// Outcome of taking a bounded number of bytes from the front of a string.
#[derive(Debug, Clone)]
enum DrainedString {
    /// The entire string fit within the byte limit.
    CompleteString(String),
    /// The string had to be cut in two.
    PartialString {
        /// Leading part that fit; cut on a UTF-8 boundary and possibly empty.
        drained_portion: String,
        /// Everything that followed `drained_portion`.
        remainder: String,
    },
}
29
30fn drain_at_most_n_bytes(mut buf: String, max: usize) -> DrainedString {
35 let mut to_drain = std::cmp::min(buf.len(), max);
36 if buf.len() <= to_drain {
37 return DrainedString::CompleteString(buf);
38 }
39 while to_drain > 0 && !buf.is_char_boundary(to_drain) {
40 to_drain -= 1;
41 }
42 let drained_portion = buf.drain(0..to_drain).collect::<String>();
43 let remainder = buf;
44 DrainedString::PartialString {
45 drained_portion,
46 remainder,
47 }
48}
49
50fn all_chunkable_data_consumed(msg: &PluginQuery) -> bool {
55 msg.key.is_empty() && msg.output.is_empty() && msg.concern.is_empty()
56}
57
58pub fn chunk_with_size(msg: PluginQuery, max_est_size: usize) -> Result<Vec<PluginQuery>> {
59 let (in_progress_state, completion_state) = match msg.state() {
64 QueryState::Unspecified => return Err(anyhow!("msg in Unspecified query state")),
66 QueryState::SubmitInProgress | QueryState::SubmitComplete => {
67 (QueryState::SubmitInProgress, QueryState::SubmitComplete)
68 }
69 QueryState::ReplyInProgress | QueryState::ReplyComplete => {
70 (QueryState::ReplyInProgress, QueryState::ReplyComplete)
71 }
72 };
73
74 let null_key = msg.key.is_empty();
75 let null_output = msg.output.is_empty();
76
77 let mut out: Vec<PluginQuery> = vec![];
78 let mut base: PluginQuery = msg;
79
80 let mut made_progress = true;
82 while !all_chunkable_data_consumed(&base) {
83 if !made_progress {
84 return Err(anyhow!("Message could not be chunked"));
85 }
86 made_progress = false;
87
88 let mut remaining = max_est_size;
91 let mut chunked_query = PluginQuery {
92 id: base.id,
93 state: in_progress_state as i32,
94 publisher_name: base.publisher_name.clone(),
95 plugin_name: base.plugin_name.clone(),
96 query_name: base.query_name.clone(),
97 key: vec![],
98 output: vec![],
99 concern: vec![],
100 split: false,
101 };
102
103 for (source, sink) in [
104 (&mut base.key, &mut chunked_query.key),
105 (&mut base.output, &mut chunked_query.output),
106 (&mut base.concern, &mut chunked_query.concern),
107 ] {
108 let split_occurred = drain_vec_string(source, sink, &mut remaining, &mut made_progress);
109 if split_occurred {
110 chunked_query.split = true;
111 break;
112 }
113 if remaining == 0 {
114 break;
115 }
116 }
117
118 if cfg!(feature = "rfd9-compat") {
120 if chunked_query.key.is_empty() {
122 chunked_query.key.push("".to_owned());
123 }
124 if chunked_query.output.is_empty() {
126 chunked_query.output.push("".to_owned());
127 }
128 }
129
130 out.push(chunked_query);
131 }
132
133 if let Some(last) = out.last_mut() {
135 last.state = completion_state as i32;
136 }
137
138 if (cfg!(feature = "rfd9-compat") && (null_key || null_output))
142 && let Some(first) = out.first_mut()
143 {
144 if null_key && let Some(k) = first.key.first_mut() {
145 *k = "null".to_owned()
146 }
147 if null_output && let Some(o) = first.output.first_mut() {
148 *o = "null".to_owned()
149 }
150 }
151
152 Ok(out)
153}
154
155pub fn chunk(msg: PluginQuery) -> Result<Vec<PluginQuery>> {
156 chunk_with_size(msg, GRPC_EFFECTIVE_MAX_SIZE)
157}
158
159pub fn prepare(msg: Query) -> Result<Vec<PluginQuery>> {
160 chunk(msg.try_into()?)
161}
162
163fn drain_vec_string(
170 source: &mut Vec<String>,
171 sink: &mut Vec<String>,
172 remaining: &mut usize,
173 made_progress: &mut bool,
174) -> bool {
175 while !source.is_empty() {
176 let s_to_drain = source.remove(0);
180 let drained_str = drain_at_most_n_bytes(s_to_drain, *remaining);
181 match drained_str {
182 DrainedString::CompleteString(complete) => {
183 *made_progress = true;
184 *remaining -= complete.len();
185 sink.push(complete);
186 }
187 DrainedString::PartialString {
188 drained_portion,
189 remainder,
190 } => {
191 let split = !drained_portion.is_empty();
193 if split {
194 *made_progress = true;
195 *remaining -= drained_portion.len();
196 sink.push(drained_portion);
197 }
198 source.insert(0, remainder);
201 return split;
202 }
203 }
204 }
205 false
206}
207
208fn in_progress_state(state: &QueryState) -> bool {
210 matches!(
211 state,
212 QueryState::ReplyInProgress | QueryState::SubmitInProgress
213 )
214}
215
/// Names one of the repeated string fields of a `PluginQuery` that chunking
/// can split across messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum QueryVecField {
    /// The `key` field.
    Key,
    /// The `output` field.
    Output,
    /// The `concern` field.
    Concern,
}
223
224fn last_field_to_have_content(query: &PluginQuery) -> QueryVecField {
231 if !query.concern.is_empty() {
232 return QueryVecField::Concern;
233 }
234 if cfg!(feature = "rfd9-compat") {
238 if !(query.output.len() == 1
239 && (query.output.first().unwrap() == "" || query.output.first().unwrap() == "null"))
240 {
241 return QueryVecField::Output;
242 }
243 } else if !query.output.is_empty() {
244 return QueryVecField::Output;
245 }
246 QueryVecField::Key
247}
248
249#[derive(Default)]
250pub struct QuerySynthesizer {
251 raw: Option<PluginQuery>,
252}
253
254impl QuerySynthesizer {
255 pub fn add<I>(&mut self, mut chunks: I) -> StdResult<Option<Query>, Error>
256 where
257 I: Iterator<Item = PluginQuery>,
258 {
259 if self.raw.is_none() {
260 self.raw = match chunks.next() {
261 Some(x) => Some(x),
262 None => {
263 return Ok(None);
264 }
265 };
266 }
267 let raw = self.raw.as_mut().unwrap(); let initial_state: QueryState = raw
270 .state
271 .try_into()
272 .map_err(|_| Error::UnspecifiedQueryState)?;
273 let mut current_state: QueryState = initial_state;
275
276 let mut last_message_split: Option<QueryVecField> = if raw.split {
279 Some(last_field_to_have_content(raw))
280 } else {
281 None
282 };
283
284 if in_progress_state(¤t_state) {
286 while in_progress_state(¤t_state) {
287 let mut next = match chunks.next() {
290 Some(msg) => msg,
291 None => {
292 return Ok(None);
293 }
294 };
295
296 current_state = next
298 .state
299 .try_into()
300 .map_err(|_| Error::UnspecifiedQueryState)?;
301 match (initial_state, current_state) {
302 (QueryState::Unspecified, _)
304 | (QueryState::ReplyComplete, _)
305 | (QueryState::SubmitComplete, _) => {
306 unreachable!()
307 }
308
309 (_, QueryState::Unspecified) => return Err(Error::UnspecifiedQueryState),
311 (QueryState::SubmitInProgress, QueryState::ReplyInProgress)
313 | (QueryState::SubmitInProgress, QueryState::ReplyComplete) => {
314 return Err(Error::ReceivedReplyWhenExpectingSubmitChunk);
315 }
316 (QueryState::ReplyInProgress, QueryState::SubmitInProgress)
318 | (QueryState::ReplyInProgress, QueryState::SubmitComplete) => {
319 return Err(Error::ReceivedSubmitWhenExpectingReplyChunk);
320 }
321 (_, _) => {
323 if current_state == QueryState::ReplyComplete {
324 raw.set_state(QueryState::ReplyComplete);
325 }
326 if current_state == QueryState::SubmitComplete {
327 raw.set_state(QueryState::SubmitComplete);
328 }
329
330 let next_message_split = if next.split {
331 Some(last_field_to_have_content(&next))
332 } else {
333 None
334 };
335
336 if let Some(split_field) = last_message_split {
344 match split_field {
345 QueryVecField::Key => {
346 raw.key
347 .last_mut()
348 .unwrap()
349 .push_str(next.key.remove(0).as_str());
350 }
351 QueryVecField::Output => {
352 raw.output
353 .last_mut()
354 .unwrap()
355 .push_str(next.output.remove(0).as_str());
356 }
357 QueryVecField::Concern => {
358 raw.concern
359 .last_mut()
360 .unwrap()
361 .push_str(next.concern.remove(0).as_str());
362 }
363 }
364 }
365
366 raw.key.extend(next.key);
367 raw.output.extend(next.output);
368 raw.concern.extend(next.concern);
369
370 last_message_split = next_message_split;
372 }
373 };
374 }
375
376 if chunks.next().is_some() {
378 return Err(Error::MoreAfterQueryComplete {
379 id: raw.id as usize,
380 });
381 }
382 }
383
384 self.raw.take().unwrap().try_into().map(Some)
385 }
386}
387
#[cfg(test)]
mod test {

    use super::*;

    /// Draining a multi-byte string at a byte limit must cut on a UTF-8
    /// character boundary, never exceed the limit, and lose no data.
    #[test]
    fn test_bounded_char_draining() {
        let orig_key = "aこれは実験です".to_owned();
        let max_size = 10;
        let res = drain_at_most_n_bytes(orig_key.clone(), max_size);
        let (drained, remainder) = match res {
            DrainedString::CompleteString(_) => panic!("expected to return PartialString"),
            DrainedString::PartialString {
                drained_portion,
                remainder,
            } => (drained_portion, remainder),
        };
        assert!((0..=max_size).contains(&drained.len()));

        let mut reassembled = drained;
        // Concatenating the two halves must reproduce the input exactly.
        reassembled.push_str(remainder.as_str());
        assert_eq!(orig_key, reassembled);
    }

    /// Repeatedly draining one ASCII string with a fixed per-call budget adds
    /// one sink entry per call, and reports a split exactly while data remains
    /// in `source`.
    #[test]
    fn test_draining_vec() {
        let mut source = vec!["123456".to_owned()];
        let mut sink = vec![];

        while !source.is_empty() {
            // `&mut 1` creates a fresh 1-byte budget on every call.
            let mut made_progress = false;
            let partial = drain_vec_string(&mut source, &mut sink, &mut 1, &mut made_progress);
            assert_eq!(partial, !source.is_empty())
        }
        assert_eq!(sink.len(), 6);
        assert!(source.is_empty());

        // The same string with a 3-byte budget ends up in two pieces.
        let mut source = vec!["123456".to_owned()];
        let mut sink = vec![];
        while !source.is_empty() {
            let mut made_progress = false;
            let partial = drain_vec_string(&mut source, &mut sink, &mut 3, &mut made_progress);
            assert_eq!(partial, !source.is_empty())
        }
        assert_eq!(sink.len(), 2);
        assert!(source.is_empty());
    }

    /// A single 3-byte character with a 1-byte budget cannot be split, so
    /// nothing is drained and no progress is reported.
    #[test]
    fn test_char_boundary_respected() {
        let mut source = vec!["実".to_owned()];
        let mut sink = vec![];
        let mut made_progress = false;
        drain_vec_string(&mut source, &mut sink, &mut 1, &mut made_progress);
        assert!(!made_progress);
    }

    /// With a 4-byte budget (at least as wide as every character here), each
    /// call moves some data, even when non-ASCII text straddles the limit.
    #[test]
    fn test_non_ascii_drain_vec_string_makes_progress() {
        let mut source = vec!["1234".to_owned(), "aこれ".to_owned(), "abcdef".to_owned()];
        let mut sink = vec![];

        while !source.is_empty() {
            // Fresh budget and progress flag for every call.
            let remaining = &mut 4;
            let made_progress = &mut false;
            drain_vec_string(&mut source, &mut sink, remaining, made_progress);
            assert!(*made_progress);
        }
        // The first string fits the budget exactly and is moved whole.
        assert_eq!(sink.first().unwrap(), "1234");
        assert!(source.is_empty());
    }

    /// The return value of `drain_vec_string` reports whether a string was split.
    #[test]
    fn test_drain_vec_string_split_detection() {
        // A 4-byte string with a 3-byte budget is split, leaving "4" in
        // `source`.
        let mut max_len = 3;
        let mut source = vec!["1234".to_owned()];
        let mut sink = vec![];
        let mut made_progress = false;
        let split = drain_vec_string(&mut source, &mut sink, &mut max_len, &mut made_progress);
        assert!(split);
        assert_eq!(source, vec!["4"]);
        assert!(made_progress);
        assert_eq!(source.len(), 1);
        assert_eq!(sink.len(), 1);

        // A 9-byte string with a 10-byte budget fits whole, so there is no
        // split.
        let mut max_len = 10;
        let mut source = vec!["123456789".to_owned()];
        let mut sink = vec![];
        let mut made_progress = false;
        let split = drain_vec_string(&mut source, &mut sink, &mut max_len, &mut made_progress);
        assert!(!split);
        assert!(source.is_empty());
        assert!(made_progress);
        assert_eq!(sink.len(), 1);
    }

    /// Round trip: chunking a query with a 10-byte budget and feeding the
    /// chunks to a `QuerySynthesizer` must reproduce the original query.
    /// This is checked for both submit and reply queries.
    #[test]
    fn test_chunking_and_query_reconstruction() {
        let states = [
            // (state of every chunk but the last, state of the last chunk)
            (QueryState::SubmitInProgress, QueryState::SubmitComplete),
            (QueryState::ReplyInProgress, QueryState::ReplyComplete),
        ];

        for (intermediate_state, final_state) in states.into_iter() {
            // Under `rfd9-compat` an absent output is carried as "null".
            let output = if cfg!(feature = "rfd9-compat") {
                vec!["null".to_owned()]
            } else {
                vec![]
            };
            let orig_query = PluginQuery {
                id: 0,
                state: final_state as i32,
                publisher_name: "".to_owned(),
                plugin_name: "".to_owned(),
                query_name: "".to_owned(),
                // A JSON string with multi-byte characters, longer than the
                // budget, so it must be split across chunks.
                key: vec![serde_json::to_string("aこれは実験です").unwrap()],
                output,
                // Concerns both under and exactly at the 10-byte budget.
                concern: vec![
                    "< 10".to_owned(),
                    "0123456789".to_owned(),
                    "< 10#2".to_owned(),
                ],
                split: false,
            };
            let res = match chunk_with_size(orig_query.clone(), 10) {
                Ok(r) => r,
                Err(e) => {
                    panic!("chunk_with_size unexpectedly errored: {e}");
                }
            };

            // Every chunk but the last is in progress; the last is complete.
            res[..res.len() - 1]
                .iter()
                .for_each(|x| assert_eq!(x.state(), intermediate_state));
            assert_eq!(res.last().unwrap().state(), final_state);

            // Feed all chunks at once and compare the reassembled query.
            let mut synth = QuerySynthesizer::default();
            let synthesized_query = synth.add(res.into_iter()).unwrap();

            let synthesized_plugin_query: PluginQuery =
                synthesized_query.unwrap().try_into().unwrap();
            assert_eq!(orig_query, synthesized_plugin_query);
        }
    }
}