sonic-channel 1.1.0

Rust client for the Sonic search backend
Documentation
use std::io::{self, BufWriter, Write};
use std::{path::PathBuf, str::FromStr};

use crate::{result::*, ChannelMode};

/// Formats outgoing [`Request`]s into command bytes and parses incoming
/// response lines into [`Response`]s.
#[derive(Debug, Default)]
pub struct Protocol {
    // The version is stored but not read by any code visible in this file,
    // which is why the `dead_code` lint is silenced here.
    #[allow(dead_code)]
    version: Version,
}

impl From<Version> for Protocol {
    fn from(version: Version) -> Self {
        Self { version }
    }
}

impl Protocol {
    pub fn format_request(&self, req: Request) -> io::Result<Vec<u8>> {
        let mut res = BufWriter::new(Vec::new());

        match req {
            Request::Quit => write!(res, "QUIT")?,

            Request::Ping => write!(res, "PING")?,

            Request::Start { mode, password } => write!(res, "START {} {}", mode, password)?,

            #[rustfmt::skip]
            Request::Count { collection, bucket, object } => match (bucket, object) {
                (Some(b), Some(o)) => write!(res, "COUNT {} {} {}", collection, b, o)?,
                (Some(b), None) => write!(res, "COUNT {} {}", collection, b)?,
                (None, None) => write!(res, "COUNT {}", collection)?,
                _ => panic!("Wrong protocol format"),
            },

            #[rustfmt::skip]
            Request::Flush { collection, bucket, object } => match (bucket, object) {
                (Some(b), Some(o)) => write!(res, "FLUSHO {} {} {}", collection, b, o)?,
                (Some(b), None) => write!(res, "FLUSHB {} {}", collection, b)?,
                (None, None) => write!(res, "FLUSHC {}", collection)?,
                _ => panic!("Wrong protocol format"),
            },

            #[rustfmt::skip]
            Request::Pop { collection, bucket, object, terms } => {
                write!(res, "POP {} {} {} \"{}\"", collection, bucket, object, terms)?
            },
            #[rustfmt::skip]
            Request::Push { collection, bucket, object, terms, lang } => {
                let oneline_terms = remove_multiline(&terms);
                write!(res, "PUSH {} {} {} \"{}\"", collection, bucket, object, oneline_terms)?;
                if let Some(lang) = lang {
                    write!(res, " LANG({})", lang)?
                }
            }

            #[rustfmt::skip]
            Request::Query { collection, bucket, terms, offset, limit, lang } => {
                write!(res, "QUERY {} {} \"{}\"", collection, bucket, terms)?;
                if let Some(limit) = limit {
                    write!(res, " LIMIT({})", limit)?;
                }
                if let Some(offset) = offset {
                    write!(res, " OFFSET({})", offset)?;
                }
                if let Some(lang) = lang {
                    write!(res, " LANG({})", lang)?;
                }
            }
            #[rustfmt::skip]
            Request::Suggest { collection, bucket, word, limit } => {
                write!(res, "SUGGEST {} {} \"{}\"", collection, bucket, word)?;
                if let Some(limit) = limit {
                    write!(res, " LIMIT({})", limit)?;
                }
            }

            #[rustfmt::skip]
            Request::List { collection, bucket, limit, offset } => {
                write!(res, "LIST {} {}", collection, bucket)?;
                if let Some(limit) = limit {
                    write!(res, " LIMIT({})", limit)?;
                }
                if let Some(offset) = offset {
                    write!(res, " OFFSET({})", offset)?;
                }
            }

            Request::Trigger(triger_req) => match triger_req {
                TriggerRequest::Consolidate => write!(res, "TRIGGER consolidate")?,
                TriggerRequest::Backup(path) => {
                    write!(res, "TRIGGER backup {}", path.to_str().unwrap())?
                }
                TriggerRequest::Restore(path) => {
                    write!(res, "TRIGGER restore {}", path.to_str().unwrap())?
                }
            },
        }

        write!(res, "\r\n")?;
        res.flush()?;

        Ok(res.into_inner()?)
    }

    pub fn parse_response(&self, line: &str) -> Result<Response> {
        let mut segments = line.split_whitespace();
        match segments.next() {
            Some("STARTED") => match (segments.next(), segments.next(), segments.next()) {
                (Some(_raw_mode), Some(raw_protocol), Some(raw_buffer_size)) => {
                    Ok(Response::Started(StartedPayload {
                        protocol_version: parse_server_config(raw_protocol)?,
                        max_buffer_size: parse_server_config(raw_buffer_size)?,
                    }))
                }
                _ => Err(Error::WrongResponse),
            },
            Some("PENDING") => {
                let event_id = segments
                    .next()
                    .map(String::from)
                    .ok_or(Error::WrongResponse)?;
                Ok(Response::Pending(event_id))
            }
            Some("RESULT") => match segments.next() {
                Some(num) => num
                    .parse()
                    .map(Response::Result)
                    .map_err(|_| Error::WrongResponse),
                _ => Err(Error::WrongResponse),
            },
            Some("EVENT") => {
                let event_kind = match segments.next() {
                    Some("SUGGEST") => Ok(EventKind::Suggest),
                    Some("QUERY") => Ok(EventKind::Query),
                    Some("LIST") => Ok(EventKind::List),
                    _ => Err(Error::WrongResponse),
                }?;

                let event_id = segments
                    .next()
                    .map(String::from)
                    .ok_or(Error::WrongResponse)?;

                let objects = segments.map(String::from).collect();

                Ok(Response::Event(event_kind, event_id, objects))
            }
            Some("OK") => Ok(Response::Ok),
            Some("ENDED") => Ok(Response::Ended),
            Some("CONNECTED") => Ok(Response::Connected),
            Some("ERR") => match segments.next() {
                Some(message) => Err(Error::SonicServer(String::from(message))),
                _ => Err(Error::WrongResponse),
            },
            _ => Err(Error::WrongResponse),
        }
    }
}

//===========================================================================//
// Primitives                                                                //
//===========================================================================//

/// Protocol version, represented as a `u8` whose discriminant is the
/// numeric version.
#[derive(Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum Version {
    /// Version `1`, the only version defined here.
    V1 = 1,
}

impl Default for Version {
    fn default() -> Self {
        Self::V1
    }
}

impl TryFrom<u8> for Version {
    type Error = ();

    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
        match value {
            1 => Ok(Self::V1),
            _ => Err(()),
        }
    }
}

//===========================================================================//
// Response                                                                  //
//===========================================================================//

pub type EventId = String;

/// A successfully parsed response line.
#[derive(Debug)]
pub enum Response {
    /// An `OK` line.
    Ok,
    /// An `ENDED` line.
    Ended,
    /// A `CONNECTED` line.
    Connected,
    /// A `PENDING` line; carries the event id that follows the keyword.
    Pending(EventId),
    /// Presumably the reply to a `PING` request.
    /// NOTE(review): confirm which response line maps to this variant.
    Pong,
    /// A `STARTED` line with its parsed configuration values.
    Started(StartedPayload),
    /// A `RESULT` line; carries the number that follows the keyword.
    Result(usize),
    /// An `EVENT` line: the event kind, the event id, and every remaining
    /// whitespace-separated token of the line.
    Event(EventKind, EventId, Vec<String>),
}

/// Values extracted from a `STARTED` response line.
#[derive(Debug)]
pub struct StartedPayload {
    /// Parsed from the parenthesized value of the second `STARTED` argument
    /// (e.g. `protocol(1)`).
    pub protocol_version: u8,
    /// Parsed from the parenthesized value of the third `STARTED` argument
    /// (e.g. `buffer_size(20000)`).
    pub max_buffer_size: usize,
}

/// The kind token that follows `EVENT` in a response line.
#[derive(Debug)]
pub enum EventKind {
    /// `EVENT SUGGEST ...`
    Suggest,
    /// `EVENT QUERY ...`
    Query,
    /// `EVENT LIST ...`
    List,
}

//===========================================================================//
// Request                                                                   //
//===========================================================================//

/// A command that can be serialized with `Protocol::format_request`.
#[derive(Debug)]
pub enum Request {
    /// `START <mode> <password>`
    Start {
        mode: ChannelMode,
        password: String,
    },
    /// `QUIT`
    Quit,
    /// `PING`
    Ping,
    /// `TRIGGER <action> ...`; see [`TriggerRequest`].
    Trigger(TriggerRequest),
    /// `SUGGEST <collection> <bucket> "<word>"`, optionally followed by
    /// `LIMIT(<limit>)`.
    Suggest {
        collection: String,
        bucket: String,
        word: String,
        limit: Option<usize>,
    },
    /// `LIST <collection> <bucket>`, optionally followed by `LIMIT(<limit>)`
    /// and `OFFSET(<offset>)`.
    List {
        collection: String,
        bucket: String,
        limit: Option<usize>,
        offset: Option<usize>,
    },
    /// `QUERY <collection> <bucket> "<terms>"`, optionally followed by
    /// `LIMIT(<limit>)`, `OFFSET(<offset>)` and `LANG(<lang>)`.
    Query {
        collection: String,
        bucket: String,
        terms: String,
        offset: Option<usize>,
        limit: Option<usize>,
        lang: Option<&'static str>,
    },
    /// `PUSH <collection> <bucket> <object> "<terms>"`, optionally followed
    /// by `LANG(<lang>)`. The terms are collapsed onto one line when sent.
    Push {
        collection: String,
        bucket: String,
        object: String,
        terms: String,
        lang: Option<&'static str>,
    },
    /// `POP <collection> <bucket> <object> "<terms>"`
    Pop {
        collection: String,
        bucket: String,
        object: String,
        terms: String,
    },
    /// `FLUSHC <collection>`, `FLUSHB <collection> <bucket>` or
    /// `FLUSHO <collection> <bucket> <object>`, depending on which optional
    /// fields are set. An object without a bucket is not a valid combination.
    Flush {
        collection: String,
        bucket: Option<String>,
        object: Option<String>,
    },
    /// `COUNT <collection> [<bucket> [<object>]]`. An object without a
    /// bucket is not a valid combination.
    Count {
        collection: String,
        bucket: Option<String>,
        object: Option<String>,
    },
}

/// The action carried by a `TRIGGER` command.
#[derive(Debug)]
pub enum TriggerRequest {
    /// `TRIGGER consolidate`
    Consolidate,
    /// `TRIGGER backup <path>`
    Backup(PathBuf),
    /// `TRIGGER restore <path>`
    Restore(PathBuf),
}

//===========================================================================//
// Utils                                                                     //
//===========================================================================//

fn parse_server_config<T: FromStr>(raw: &str) -> Result<T> {
    raw.split_terminator(&['(', ')'])
        .nth(1)
        .ok_or(Error::WrongResponse)?
        .parse()
        .map_err(|_| Error::WrongResponse)
}

/// Collapses multi-line `text` onto a single line.
///
/// Lines are concatenated in order. A single space is inserted before a
/// non-empty line unless nothing has been written yet or the text collected
/// so far already ends with a space. Empty lines contribute nothing.
fn remove_multiline(text: &str) -> String {
    let mut joined = String::with_capacity(text.len());

    for line in text.lines() {
        let needs_separator = !(line.is_empty() || joined.is_empty() || joined.ends_with(' '));
        if needs_separator {
            joined.push(' ');
        }
        joined.push_str(line);
    }

    joined
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The protocol version is read from between the parentheses.
    #[test]
    fn should_parse_protocol() {
        let parsed = parse_server_config::<u8>("protocol(1)").ok();
        assert_eq!(parsed, Some(1));
    }

    /// The buffer size is read from between the parentheses.
    #[test]
    fn should_parse_buffer_size() {
        let parsed = parse_server_config::<usize>("buffer_size(20000)").ok();
        assert_eq!(parsed, Some(20000));
    }

    /// Leading/trailing blank lines vanish and inner line breaks become spaces.
    #[test]
    fn should_make_single_line() {
        let text = "
Hello
World
";

        assert_eq!(remove_multiline(text), "Hello World");
    }
}