1use crate::{
4 error::Error,
5 proto::{Query as PluginQuery, QueryState},
6 types::Query,
7};
8use anyhow::{anyhow, Result};
9use std::result::Result as StdResult;
10
/// Upper bound, in bytes, used for the size of a single gRPC message (4 MiB).
pub const GRPC_MAX_SIZE_BYTES: usize = 4 * 1024 * 1024;

/// Byte budget actually handed to the chunker: the maximum minus 1 KiB of
/// headroom (presumably for the fields of a message that are not chunked).
const GRPC_EFFECTIVE_MAX_SIZE: usize = GRPC_MAX_SIZE_BYTES - 1024;

/// Outcome of trying to take a bounded number of bytes off the front of a
/// string.
#[derive(Clone, Debug)]
enum DrainedString {
    /// The whole string fit within the limit and is returned untouched.
    CompleteString(String),
    /// Only a prefix fit within the limit.
    PartialString {
        /// Prefix that fit; may be empty when not even one character fit.
        drained_portion: String,
        /// Everything that was left behind.
        remainder: String,
    },
}

/// Takes at most `max` bytes off the front of `buf`.
///
/// When `buf` is no longer than `max` bytes it is returned whole. Otherwise the
/// cut is placed on the last UTF-8 character boundary at or below `max`, so no
/// character is ever split and the drained portion may be shorter than `max`
/// (or even empty).
fn drain_at_most_n_bytes(mut buf: String, max: usize) -> DrainedString {
    if buf.len() <= max {
        return DrainedString::CompleteString(buf);
    }

    // `buf` is longer than `max`, so every candidate index is in range; index 0
    // is always a boundary, which guarantees the search succeeds.
    let cut = (0..=max)
        .rev()
        .find(|&idx| buf.is_char_boundary(idx))
        .unwrap_or(0);

    // `split_off` leaves the first `cut` bytes in `buf` and returns the tail.
    let remainder = buf.split_off(cut);
    DrainedString::PartialString {
        drained_portion: buf,
        remainder,
    }
}
49
50fn all_chunkable_data_consumed(msg: &PluginQuery) -> bool {
55 msg.key.is_empty() && msg.output.is_empty() && msg.concern.is_empty()
56}
57
58pub fn chunk_with_size(msg: PluginQuery, max_est_size: usize) -> Result<Vec<PluginQuery>> {
59 let (in_progress_state, completion_state) = match msg.state() {
64 QueryState::Unspecified => return Err(anyhow!("msg in Unspecified query state")),
66 QueryState::SubmitInProgress | QueryState::SubmitComplete => {
67 (QueryState::SubmitInProgress, QueryState::SubmitComplete)
68 }
69 QueryState::ReplyInProgress | QueryState::ReplyComplete => {
70 (QueryState::ReplyInProgress, QueryState::ReplyComplete)
71 }
72 };
73
74 let null_key = msg.key.is_empty();
75 let null_output = msg.output.is_empty();
76
77 let mut out: Vec<PluginQuery> = vec![];
78 let mut base: PluginQuery = msg;
79
80 let mut made_progress = true;
82 while !all_chunkable_data_consumed(&base) {
83 if !made_progress {
84 return Err(anyhow!("Message could not be chunked"));
85 }
86 made_progress = false;
87
88 let mut remaining = max_est_size;
91 let mut chunked_query = PluginQuery {
92 id: base.id,
93 state: in_progress_state as i32,
94 publisher_name: base.publisher_name.clone(),
95 plugin_name: base.plugin_name.clone(),
96 query_name: base.query_name.clone(),
97 key: vec![],
98 output: vec![],
99 concern: vec![],
100 split: false,
101 };
102
103 for (source, sink) in [
104 (&mut base.key, &mut chunked_query.key),
105 (&mut base.output, &mut chunked_query.output),
106 (&mut base.concern, &mut chunked_query.concern),
107 ] {
108 let split_occurred = drain_vec_string(source, sink, &mut remaining, &mut made_progress);
109 if split_occurred {
110 chunked_query.split = true;
111 break;
112 }
113 if remaining == 0 {
114 break;
115 }
116 }
117
118 if cfg!(feature = "rfd9-compat") {
120 if chunked_query.key.is_empty() {
122 chunked_query.key.push("".to_owned());
123 }
124 if chunked_query.output.is_empty() {
126 chunked_query.output.push("".to_owned());
127 }
128 }
129
130 out.push(chunked_query);
131 }
132
133 if let Some(last) = out.last_mut() {
135 last.state = completion_state as i32;
136 }
137
138 if cfg!(feature = "rfd9-compat") && (null_key || null_output) {
142 if let Some(first) = out.first_mut() {
143 if null_key {
144 if let Some(k) = first.key.first_mut() {
145 *k = "null".to_owned()
146 }
147 }
148 if null_output {
149 if let Some(o) = first.output.first_mut() {
150 *o = "null".to_owned()
151 }
152 }
153 }
154 }
155
156 Ok(out)
157}
158
159pub fn chunk(msg: PluginQuery) -> Result<Vec<PluginQuery>> {
160 chunk_with_size(msg, GRPC_EFFECTIVE_MAX_SIZE)
161}
162
163pub fn prepare(msg: Query) -> Result<Vec<PluginQuery>> {
164 chunk(msg.try_into()?)
165}
166
167fn drain_vec_string(
174 source: &mut Vec<String>,
175 sink: &mut Vec<String>,
176 remaining: &mut usize,
177 made_progress: &mut bool,
178) -> bool {
179 while !source.is_empty() {
180 let s_to_drain = source.remove(0);
184 let drained_str = drain_at_most_n_bytes(s_to_drain, *remaining);
185 match drained_str {
186 DrainedString::CompleteString(complete) => {
187 *made_progress = true;
188 *remaining -= complete.len();
189 sink.push(complete);
190 }
191 DrainedString::PartialString {
192 drained_portion,
193 remainder,
194 } => {
195 let split = !drained_portion.is_empty();
197 if split {
198 *made_progress = true;
199 *remaining -= drained_portion.len();
200 sink.push(drained_portion);
201 }
202 source.insert(0, remainder);
205 return split;
206 }
207 }
208 }
209 false
210}
211
212fn in_progress_state(state: &QueryState) -> bool {
214 matches!(
215 state,
216 QueryState::ReplyInProgress | QueryState::SubmitInProgress
217 )
218}
219
220#[derive(Debug)]
222enum QueryVecField {
223 Key,
224 Output,
225 Concern,
226}
227
228fn last_field_to_have_content(query: &PluginQuery) -> QueryVecField {
235 if !query.concern.is_empty() {
236 return QueryVecField::Concern;
237 }
238 if cfg!(feature = "rfd9-compat") {
242 if !(query.output.len() == 1
243 && (query.output.first().unwrap() == "" || query.output.first().unwrap() == "null"))
244 {
245 return QueryVecField::Output;
246 }
247 } else if !query.output.is_empty() {
248 return QueryVecField::Output;
249 }
250 QueryVecField::Key
251}
252
253#[derive(Default)]
254pub struct QuerySynthesizer {
255 raw: Option<PluginQuery>,
256}
257
258impl QuerySynthesizer {
259 pub fn add<I>(&mut self, mut chunks: I) -> StdResult<Option<Query>, Error>
260 where
261 I: Iterator<Item = PluginQuery>,
262 {
263 if self.raw.is_none() {
264 self.raw = match chunks.next() {
265 Some(x) => Some(x),
266 None => {
267 return Ok(None);
268 }
269 };
270 }
271 let raw = self.raw.as_mut().unwrap(); let initial_state: QueryState = raw
274 .state
275 .try_into()
276 .map_err(|_| Error::UnspecifiedQueryState)?;
277 let mut current_state: QueryState = initial_state;
279
280 let mut last_message_split: Option<QueryVecField> = if raw.split {
283 Some(last_field_to_have_content(raw))
284 } else {
285 None
286 };
287
288 if in_progress_state(¤t_state) {
290 while in_progress_state(¤t_state) {
291 let mut next = match chunks.next() {
294 Some(msg) => msg,
295 None => {
296 return Ok(None);
297 }
298 };
299
300 current_state = next
302 .state
303 .try_into()
304 .map_err(|_| Error::UnspecifiedQueryState)?;
305 match (initial_state, current_state) {
306 (QueryState::Unspecified, _)
308 | (QueryState::ReplyComplete, _)
309 | (QueryState::SubmitComplete, _) => {
310 unreachable!()
311 }
312
313 (_, QueryState::Unspecified) => return Err(Error::UnspecifiedQueryState),
315 (QueryState::SubmitInProgress, QueryState::ReplyInProgress)
317 | (QueryState::SubmitInProgress, QueryState::ReplyComplete) => {
318 return Err(Error::ReceivedReplyWhenExpectingSubmitChunk)
319 }
320 (QueryState::ReplyInProgress, QueryState::SubmitInProgress)
322 | (QueryState::ReplyInProgress, QueryState::SubmitComplete) => {
323 return Err(Error::ReceivedSubmitWhenExpectingReplyChunk)
324 }
325 (_, _) => {
327 if current_state == QueryState::ReplyComplete {
328 raw.set_state(QueryState::ReplyComplete);
329 }
330 if current_state == QueryState::SubmitComplete {
331 raw.set_state(QueryState::SubmitComplete);
332 }
333
334 let next_message_split = if next.split {
335 Some(last_field_to_have_content(&next))
336 } else {
337 None
338 };
339
340 if let Some(split_field) = last_message_split {
348 match split_field {
349 QueryVecField::Key => {
350 raw.key
351 .last_mut()
352 .unwrap()
353 .push_str(next.key.remove(0).as_str());
354 }
355 QueryVecField::Output => {
356 raw.output
357 .last_mut()
358 .unwrap()
359 .push_str(next.output.remove(0).as_str());
360 }
361 QueryVecField::Concern => {
362 raw.concern
363 .last_mut()
364 .unwrap()
365 .push_str(next.concern.remove(0).as_str());
366 }
367 }
368 }
369
370 raw.key.extend(next.key);
371 raw.output.extend(next.output);
372 raw.concern.extend(next.concern);
373
374 last_message_split = next_message_split;
376 }
377 };
378 }
379
380 if chunks.next().is_some() {
382 return Err(Error::MoreAfterQueryComplete {
383 id: raw.id as usize,
384 });
385 }
386 }
387
388 self.raw.take().unwrap().try_into().map(Some)
389 }
390}
391
#[cfg(test)]
mod test {

    use super::*;

    /// Cutting a multi-byte string at a byte limit yields a prefix within the
    /// limit and loses no data.
    #[test]
    fn test_bounded_char_draining() {
        let original = "aこれは実験です".to_owned();
        let byte_limit = 10;

        let (head, tail) = match drain_at_most_n_bytes(original.clone(), byte_limit) {
            DrainedString::PartialString {
                drained_portion,
                remainder,
            } => (drained_portion, remainder),
            DrainedString::CompleteString(_) => panic!("expected to return PartialString"),
        };
        assert!(head.len() <= byte_limit);

        // Gluing the two parts back together must give the original string.
        assert_eq!(original, format!("{head}{tail}"));
    }

    /// Draining one string with a fresh budget per call reports a split on
    /// every call except the one that empties the source.
    #[test]
    fn test_draining_vec() {
        for (budget_per_call, expected_pieces) in [(1usize, 6usize), (3, 2)] {
            let mut source = vec!["123456".to_owned()];
            let mut sink = vec![];

            while !source.is_empty() {
                let mut budget = budget_per_call;
                let mut made_progress = false;
                let partial =
                    drain_vec_string(&mut source, &mut sink, &mut budget, &mut made_progress);
                assert_eq!(partial, !source.is_empty());
            }
            assert_eq!(sink.len(), expected_pieces);
            assert!(source.is_empty());
        }
    }

    /// A multi-byte character that does not fit the budget is not cut.
    #[test]
    fn test_char_boundary_respected() {
        let mut source = vec!["実".to_owned()];
        let mut sink = vec![];
        let mut budget = 1;
        let mut made_progress = false;
        drain_vec_string(&mut source, &mut sink, &mut budget, &mut made_progress);
        assert!(!made_progress);
    }

    /// With a budget of four bytes per call, every call makes progress even
    /// when multi-byte characters are involved.
    #[test]
    fn test_non_ascii_drain_vec_string_makes_progress() {
        let mut source: Vec<String> = ["1234", "aこれ", "abcdef"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let mut sink = vec![];

        while !source.is_empty() {
            let mut budget = 4;
            let mut made_progress = false;
            drain_vec_string(&mut source, &mut sink, &mut budget, &mut made_progress);
            assert!(made_progress);
        }
        assert_eq!(sink.first().unwrap(), "1234");
        assert!(source.is_empty());
    }

    /// The return value tells a cut string apart from a string moved whole.
    #[test]
    fn test_drain_vec_string_split_detection() {
        // Budget smaller than the string: a split is reported.
        let mut budget = 3;
        let mut source = vec!["1234".to_owned()];
        let mut sink = vec![];
        let mut made_progress = false;
        let split = drain_vec_string(&mut source, &mut sink, &mut budget, &mut made_progress);
        assert!(split);
        assert_eq!(source, vec!["4"]);
        assert!(made_progress);
        assert_eq!(source.len(), 1);
        assert_eq!(sink.len(), 1);

        // Budget larger than the string: moved whole, no split.
        let mut budget = 10;
        let mut source = vec!["123456789".to_owned()];
        let mut sink = vec![];
        let mut made_progress = false;
        let split = drain_vec_string(&mut source, &mut sink, &mut budget, &mut made_progress);
        assert!(!split);
        assert!(source.is_empty());
        assert!(made_progress);
        assert_eq!(sink.len(), 1);
    }

    /// Chunking a message and feeding the chunks to the synthesizer gives back
    /// the original message, for both submit and reply flows.
    #[test]
    fn test_chunking_and_query_reconstruction() {
        let flows = [
            (QueryState::SubmitInProgress, QueryState::SubmitComplete),
            (QueryState::ReplyInProgress, QueryState::ReplyComplete),
        ];

        for (intermediate_state, final_state) in flows.into_iter() {
            let output = if cfg!(feature = "rfd9-compat") {
                vec!["null".to_owned()]
            } else {
                vec![]
            };
            let orig_query = PluginQuery {
                id: 0,
                state: final_state as i32,
                publisher_name: "".to_owned(),
                plugin_name: "".to_owned(),
                query_name: "".to_owned(),
                key: vec![serde_json::to_string("aこれは実験です").unwrap()],
                output,
                concern: vec![
                    "< 10".to_owned(),
                    "0123456789".to_owned(),
                    "< 10#2".to_owned(),
                ],
                split: false,
            };

            let chunks = chunk_with_size(orig_query.clone(), 10)
                .unwrap_or_else(|e| panic!("chunk_with_size unexpectedly errored: {e}"));

            // All chunks but the last are "in progress"; the last is final.
            let (last_chunk, earlier_chunks) = chunks.split_last().unwrap();
            for chunk in earlier_chunks {
                assert_eq!(chunk.state(), intermediate_state);
            }
            assert_eq!(last_chunk.state(), final_state);

            let mut synth = QuerySynthesizer::default();
            let synthesized_query = synth.add(chunks.into_iter()).unwrap();

            let synthesized_plugin_query: PluginQuery =
                synthesized_query.unwrap().try_into().unwrap();
            assert_eq!(orig_query, synthesized_plugin_query);
        }
    }
}