use crate::{
LmcppServer,
client::types::completion::{
CompletionRequest, CompletionRequestBuilder, CompletionResponse, completion_request_builder,
},
error::LmcppResult,
server::ipc::ServerClientExt,
};
impl LmcppServer {
    /// Requests a text completion from the server's `/completion` endpoint.
    ///
    /// `request` may be anything implementing [`CompletionRequestProvider`]:
    /// - an owned [`CompletionRequest`];
    /// - a `&` or `&mut` reference to one;
    /// - a completed [`CompletionRequestBuilder`], by value or by reference.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the underlying client `post` call, converted
    /// into the crate's error type via [`Into`].
    pub fn completion<P>(&self, request: P) -> LmcppResult<CompletionResponse>
    where
        P: CompletionRequestProvider,
    {
        // Server route that handles completion requests.
        const ENDPOINT: &str = "/completion";
        request.with_request(|body| self.client.post(ENDPOINT, body).map_err(Into::into))
    }
}
/// A value that can lend a [`CompletionRequest`] to a callback.
///
/// This abstraction lets [`LmcppServer::completion`] accept owned requests,
/// references to requests, and finished request builders interchangeably.
pub trait CompletionRequestProvider {
    /// Consumes `self` and passes a borrowed [`CompletionRequest`] to `f`.
    ///
    /// Returns whatever `f` returns. Building a request from a builder, if
    /// needed, is up to each implementation.
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&CompletionRequest) -> R;
}
/// An owned request is lent to the callback and dropped once the callback returns.
impl CompletionRequestProvider for CompletionRequest {
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&CompletionRequest) -> R,
    {
        let request = self;
        f(&request)
    }
}
/// A shared reference is forwarded to the callback unchanged; nothing is cloned.
impl CompletionRequestProvider for &CompletionRequest {
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&CompletionRequest) -> R,
    {
        let request: &CompletionRequest = self;
        f(request)
    }
}
impl<'a> CompletionRequestProvider for &'a mut CompletionRequest {
fn with_request<F, R>(self, f: F) -> R
where
F: FnOnce(&CompletionRequest) -> R,
{
f(&*self)
}
}
/// A finished builder is consumed. The request is built once and lent to the callback.
impl<S> CompletionRequestProvider for CompletionRequestBuilder<S>
where
    S: completion_request_builder::IsComplete,
{
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&CompletionRequest) -> R,
    {
        f(&self.build())
    }
}
/// A shared reference to a finished builder is cloned, and the clone is built.
/// The caller's builder is left untouched and can be reused.
impl<S> CompletionRequestProvider for &CompletionRequestBuilder<S>
where
    S: completion_request_builder::IsComplete,
{
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&CompletionRequest) -> R,
    {
        let builder = self.clone();
        f(&builder.build())
    }
}
impl<'a, S> CompletionRequestProvider for &'a mut CompletionRequestBuilder<S>
where
S: completion_request_builder::IsComplete,
{
fn with_request<F, R>(self, f: F) -> R
where
F: FnOnce(&CompletionRequest) -> R,
{
let req = self.clone().build();
f(&req)
}
}
#[cfg(test)]
mod tests {
    use serial_test::serial;

    use super::*;
    use crate::{
        client::types::generation_settings::SamplingParams, server::builder::LmcppServerLauncher,
    };

    /// End-to-end completion against a server started via `LmcppServerLauncher`.
    ///
    /// NOTE(review): ignored by default, presumably because loading a real server
    /// needs a local model/binary. Confirm. `#[serial]` keeps it from running
    /// concurrently with other `#[serial]` tests.
    #[test]
    #[ignore]
    #[serial]
    fn test_lmcpp_server_completion() -> LmcppResult<()> {
        let server = LmcppServerLauncher::default().load()?;
        let response = server.completion(
            CompletionRequest::builder()
                .prompt("Hello, world!")
                .n_predict(100),
        )?;
        println!("Completion response: {:#?}", response);
        Ok(())
    }

    /// Exercises every `CompletionRequestProvider` form accepted by `completion`.
    ///
    /// The forms are owned, `&` and `&mut` requests, plus owned, `&` and `&mut`
    /// builders. Results are discarded with `let _`, so this mainly checks that
    /// each call form type-checks and runs.
    /// NOTE(review): reason for `#[ignore]` is not visible here. Confirm.
    #[test]
    #[ignore]
    fn test_lmcpp_server_completion_variants() -> LmcppResult<()> {
        let client = LmcppServer::dummy();

        let req_owned = CompletionRequest::builder()
            .prompt("Test request")
            .sampling(SamplingParams::builder().temperature(0.7).build())
            .build();
        let _ = client.completion(req_owned);

        // Deliberately `mut` yet moved by value: shows a mutable binding can be
        // passed by value too, so the unused-mut lint is allowed for this binding only.
        #[allow(unused_mut)]
        let mut req_owned = CompletionRequest::builder().prompt("Test request").build();
        let _ = client.completion(req_owned);

        let req_owned = CompletionRequest::builder().prompt("Test request").build();
        let _ = client.completion(&req_owned);

        let mut req_owned = CompletionRequest::builder().prompt("Test request").build();
        let _ = client.completion(&mut req_owned);

        let req_owned = CompletionRequest::builder().prompt("Test request");
        let _ = client.completion(req_owned);

        // Same as above, for a builder: `mut` binding intentionally moved by value.
        #[allow(unused_mut)]
        let mut req_owned = CompletionRequest::builder().prompt("Test request");
        let _ = client.completion(req_owned);

        let req_owned = CompletionRequest::builder().prompt("Test request");
        let _ = client.completion(&req_owned);

        let mut req_owned = CompletionRequest::builder().prompt("Test request");
        let _ = client.completion(&mut req_owned);

        Ok(())
    }
}