use reqwest::Url;
use std::error::Error;
use std::sync::Arc;
use crate::collections::error::ModuleError;
use crate::collections::modules::{ContextionaryConcept, ContextionaryExtension};
/// Handle for the module endpoints served under `/v1/modules/`.
///
/// Built through `Modules::new`, which derives `endpoint` from the URL it is
/// given; the methods on this type send their requests through `client`.
#[derive(Debug)]
pub struct Modules {
/// Base URL of the modules API: the URL passed to `new` joined with `/v1/modules/`.
endpoint: Url,
/// Reference-counted HTTP client used to send every request made by this handle.
client: Arc<reqwest::Client>,
}
impl Modules {
    /// Creates the modules handle rooted at `/v1/modules/` on the given server URL.
    ///
    /// The joined path starts with a slash, so it replaces any path already
    /// present on `url` rather than being appended to it.
    ///
    /// # Errors
    ///
    /// Returns an error when `/v1/modules/` cannot be joined onto `url`.
    pub(super) fn new(url: &Url, client: Arc<reqwest::Client>) -> Result<Self, Box<dyn Error>> {
        let endpoint = url.join("/v1/modules/")?;
        Ok(Self { endpoint, client })
    }

    /// Fetches the text2vec-contextionary information for `concept`.
    ///
    /// Sends a `GET` request to `text2vec-contextionary/concepts/{concept}`
    /// below the modules endpoint. `concept` must be passed raw (not already
    /// percent-encoded): it is appended as exactly one path segment, so
    /// characters such as `/`, `?`, `#` or `%` are percent-encoded instead of
    /// being interpreted as URL syntax, and `.` / `..` cannot rewrite the path.
    ///
    /// # Errors
    ///
    /// Returns an error when the request cannot be sent, when the response
    /// status is anything other than `200 OK` (a `ModuleError` carrying the
    /// status and, if it is JSON, the response body), or when a `200 OK`
    /// body cannot be deserialized into a `ContextionaryConcept`.
    pub async fn contextionary_get_concept(
        &self,
        concept: &str,
    ) -> Result<ContextionaryConcept, Box<dyn Error>> {
        let mut endpoint = self.endpoint.clone();
        endpoint
            .path_segments_mut()
            .map_err(|()| {
                ModuleError(String::from(
                    "The modules endpoint cannot be used as a base URL.",
                ))
            })?
            // The base URL ends with `/`; drop that empty trailing segment so
            // the appended segments do not produce a `//` in the path.
            .pop_if_empty()
            .push("text2vec-contextionary")
            .push("concepts")
            .push(concept);

        let res = self.client.get(endpoint).send().await?;
        match res.status() {
            reqwest::StatusCode::OK => Ok(res.json::<ContextionaryConcept>().await?),
            _ => Err(self.get_err_msg("text2vec-contextionary concepts", res).await),
        }
    }

    /// Submits `concept` as an extension to the text2vec-contextionary module.
    ///
    /// Sends `concept` as the JSON body of a `POST` request to
    /// `text2vec-contextionary/extensions` below the modules endpoint and
    /// returns the extension deserialized from the response body.
    ///
    /// # Errors
    ///
    /// Returns an error when the request cannot be sent, when the response
    /// status is anything other than `200 OK` (a `ModuleError` carrying the
    /// status and, if it is JSON, the response body), or when a `200 OK`
    /// body cannot be deserialized into a `ContextionaryExtension`.
    pub async fn contextionary_extend(
        &self,
        concept: ContextionaryExtension,
    ) -> Result<ContextionaryExtension, Box<dyn Error>> {
        let endpoint = self.endpoint.join("text2vec-contextionary/extensions")?;
        let res = self.client.post(endpoint).json(&concept).send().await?;
        match res.status() {
            reqwest::StatusCode::OK => Ok(res.json::<ContextionaryExtension>().await?),
            _ => Err(self.get_err_msg("text2vec-contextionary extend", res).await),
        }
    }

    /// Builds the `ModuleError` reported for a response with an unexpected status.
    ///
    /// `endpoint` is a human-readable label for the call that failed; it is
    /// only used in the message text. The response body is consumed: when it
    /// parses as JSON it is appended to the message, otherwise the message
    /// contains the status code and label alone.
    async fn get_err_msg(&self, endpoint: &str, res: reqwest::Response) -> Box<ModuleError> {
        // The status must be read before `json()` consumes the response.
        let status_code = res.status();
        let message = match res.json::<serde_json::Value>().await {
            Ok(json) => format!(
                "Status code `{}` received when calling {} endpoint. Response: {}",
                status_code, endpoint, json,
            ),
            Err(_) => format!(
                "Status code `{}` received when calling {} endpoint.",
                status_code, endpoint
            ),
        };
        Box::new(ModuleError(message))
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        collections::modules::{ContextionaryConcept, ContextionaryExtension, IndividualWords},
        WeaviateClient,
    };

    /// Path the concept lookup for the word `test` is expected to hit.
    const CONCEPT_PATH: &str = "/v1/modules/text2vec-contextionary/concepts/test";
    /// Path the extension call is expected to hit.
    const EXTENSIONS_PATH: &str = "/v1/modules/text2vec-contextionary/extensions";

    /// Starts a mock HTTP server and builds a client pointed at it.
    fn get_test_harness() -> (mockito::ServerGuard, WeaviateClient) {
        let server = mockito::Server::new();
        let base_url = format!("http://{}", server.host_with_port());
        let client = WeaviateClient::builder(&base_url).build().unwrap();
        (server, client)
    }

    /// Serializes a single-word concept to use as a successful response body.
    fn get_mock_concept_response() -> String {
        let word = IndividualWords {
            info: None,
            word: "test".into(),
            present: None,
            concatenated_word: None,
        };
        let concept = ContextionaryConcept {
            individual_words: vec![word],
        };
        serde_json::to_string(&concept).unwrap()
    }

    /// Registers a mock for `method` on `endpoint` that answers with the given
    /// status code and a JSON-typed body.
    fn mock_json(
        server: &mut mockito::ServerGuard,
        method: &str,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        server
            .mock(method, endpoint)
            .with_status(status_code)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create()
    }

    #[tokio::test]
    async fn test_get_concept_ok() {
        let (mut server, client) = get_test_harness();
        let body = get_mock_concept_response();
        let mock = mock_json(&mut server, "GET", CONCEPT_PATH, 200, &body);

        let outcome = client.modules.contextionary_get_concept("test").await;

        mock.assert();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_get_concept_err() {
        let (mut server, client) = get_test_harness();
        let mock = mock_json(&mut server, "GET", CONCEPT_PATH, 401, "");

        let outcome = client.modules.contextionary_get_concept("test").await;

        mock.assert();
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_extend_ok() {
        let extension = ContextionaryExtension::new("test", "test", 1.0);
        // The mocked server echoes the submitted extension back as its reply.
        let echoed = serde_json::to_string(&extension).unwrap();
        let (mut server, client) = get_test_harness();
        let mock = mock_json(&mut server, "POST", EXTENSIONS_PATH, 200, &echoed);

        let outcome = client.modules.contextionary_extend(extension).await;

        mock.assert();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_extend_err() {
        let (mut server, client) = get_test_harness();
        let mock = mock_json(&mut server, "POST", EXTENSIONS_PATH, 401, "");
        let extension = ContextionaryExtension::new("test", "test", 1.0);

        let outcome = client.modules.contextionary_extend(extension).await;

        mock.assert();
        assert!(outcome.is_err());
    }
}