use crate::{client::WitClient, errors::Error};
use reqwest::Method;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
/// Request context for a `/message` call, already serialized to JSON text.
///
/// Produced by [`ContextBuilder::build`]; the stored JSON string is sent as the
/// `context` URL parameter when a [`MessageRequestBuilder`] is given one.
#[derive(Debug)]
pub struct Context {
    /// JSON encoding of the [`ContextBuilder`] this context was built from.
    serialized: String,
}
#[derive(Debug, Serialize)]
pub struct ContextBuilder {
reference_time: Option<String>,
timezone: Option<String>,
locale: Option<String>,
coords: Option<Coordinates>,
}
impl ContextBuilder {
pub fn new() -> Self {
Self {
reference_time: None,
timezone: None,
locale: None,
coords: None,
}
}
pub fn reference_time(mut self, reference_time: String) -> Self {
self.reference_time = Some(reference_time);
self
}
pub fn timezone(mut self, timezone: String) -> Self {
self.timezone = Some(timezone);
self
}
pub fn locale(mut self, value: String) -> Self {
self.locale = Some(value);
self
}
pub fn coords(mut self, coords: Coordinates) -> Self {
self.coords = Some(coords);
self
}
pub fn build(self) -> Context {
let serialized =
serde_json::to_string(&self).expect("should be able to serialize `Context` struct");
Context { serialized }
}
}
impl Default for ContextBuilder {
fn default() -> Self {
Self::new()
}
}
/// A latitude/longitude pair for a [`ContextBuilder`].
///
/// Serialized with the field names `lat` and `long`.
#[derive(Debug, Serialize)]
pub struct Coordinates {
    /// Latitude, stored as given (no range validation).
    lat: f64,
    /// Longitude, stored as given (no range validation).
    long: f64,
}
impl Coordinates {
pub fn new(latitude: f64, longitude: f64) -> Self {
Self {
lat: latitude,
long: longitude,
}
}
}
/// A fully-built `/message` request, ready to be sent by `WitClient::message`.
#[derive(Debug)]
pub struct MessageRequest {
    /// Query-string parameters as `(name, value)` pairs, in insertion order.
    url_params: Vec<(String, String)>,
}
/// Builder for a [`MessageRequest`].
#[derive(Debug)]
pub struct MessageRequestBuilder {
    /// Text to analyze; always sent as the `q` parameter.
    query: String,
    /// Optional value for the `tag` parameter.
    tag: Option<String>,
    /// Optional value for the `n` parameter; validated to `1..=8` by `limit`.
    n: Option<u16>,
    /// Optional serialized context, sent as the `context` parameter.
    context: Option<Context>,
}
impl MessageRequestBuilder {
pub fn new(query: String) -> Self {
MessageRequestBuilder {
query,
tag: None,
n: None,
context: None,
}
}
pub fn tag(mut self, tag: String) -> Self {
self.tag = Some(tag);
self
}
pub fn limit(mut self, limit: u16) -> Result<Self, Error> {
if !(1..=8).contains(&limit) {
return Err(Error::InvalidArgument(format!(
"limit should be between 1 and 8 inclusive, got {limit}"
)));
}
self.n = Some(limit);
Ok(self)
}
pub fn context(mut self, context: Context) -> Self {
self.context = Some(context);
self
}
pub fn build(self) -> MessageRequest {
let mut url_params = Vec::new();
url_params.push((String::from("q"), self.query));
if let Some(tag) = self.tag {
url_params.push((String::from("tag"), tag));
}
if let Some(n) = self.n {
url_params.push((String::from("n"), n.to_string()));
}
if let Some(context) = self.context {
url_params.push((String::from("context"), context.serialized));
}
MessageRequest { url_params }
}
}
/// Parsed body of a `/message` response.
#[derive(Debug, Deserialize, PartialEq)]
pub struct MessageResponse {
    /// The analyzed text.
    pub text: String,
    /// Intents detected in the text.
    pub intents: Vec<MessageIntent>,
    /// Detected entities grouped under string keys.
    /// NOTE(review): keys presumably have the form `name:role` — confirm with the API docs.
    pub entities: HashMap<String, Vec<MessageEntity>>,
    /// Detected traits grouped under string keys (presumably the trait name).
    pub traits: HashMap<String, Vec<MessageTrait>>,
}
/// One intent candidate returned by `/message`.
#[derive(Debug, Deserialize, PartialEq)]
pub struct MessageIntent {
    /// Identifier of the intent.
    pub id: String,
    /// Name of the intent.
    pub name: String,
    /// Confidence score reported for this intent.
    pub confidence: f64,
}
/// One entity extracted from the analyzed text.
#[derive(Debug, Deserialize, PartialEq)]
pub struct MessageEntity {
    /// Identifier of the entity.
    pub id: String,
    /// Name of the entity.
    pub name: String,
    /// Role the entity plays in the utterance.
    pub role: String,
    /// Start offset of `body` within the text.
    /// NOTE(review): confirm whether offsets are bytes or characters.
    pub start: u32,
    /// End offset of `body` within the text (presumably exclusive — confirm).
    pub end: u32,
    /// The fragment of text that matched this entity.
    pub body: String,
    /// Confidence score reported for this entity.
    pub confidence: f64,
    /// Nested entities under string keys.
    /// NOTE(review): this expects one entity per key; confirm the API does not
    /// return a list per key here (as it does for the top-level `entities`).
    pub entities: HashMap<String, MessageEntity>,
    /// Resolved value, if present, kept as untyped JSON.
    pub value: Option<Value>,
    /// Lower endpoint; presumably only present for interval-valued entities.
    pub from: Option<IntervalEndpoint>,
    /// Upper endpoint; presumably only present for interval-valued entities.
    pub to: Option<IntervalEndpoint>,
}
/// One end (`from` or `to`) of an interval-valued [`MessageEntity`].
#[derive(Debug, Deserialize, PartialEq)]
pub struct IntervalEndpoint {
    /// Unit of the value, when the API reports one.
    pub unit: Option<String>,
    /// Granularity of the value (e.g. for time values), when reported.
    pub grain: Option<String>,
    /// The endpoint value, kept as untyped JSON.
    pub value: Value,
}
/// One trait value detected in the analyzed text.
#[derive(Debug, Deserialize, PartialEq)]
pub struct MessageTrait {
    /// Identifier of the trait value.
    pub id: String,
    /// The trait value, kept as untyped JSON.
    pub value: Value,
    /// Confidence score reported for this trait value.
    pub confidence: f64,
}
impl WitClient {
pub async fn message(&self, request: MessageRequest) -> Result<MessageResponse, Error> {
self.make_request(
Method::GET,
"/message",
request.url_params,
Option::<Value>::None,
)
.await
}
pub async fn message_simple(&self, query: String) -> Result<MessageResponse, Error> {
let request = MessageRequestBuilder::new(query).build();
self.message(request).await
}
}