use bon::Builder;
use serde::{Deserialize, Serialize};
use crate::{
LmcppServer,
client::types::completion::{CompletionRequest, CompletionResponse},
server::ipc::{ClientError, ServerClientExt},
};
impl LmcppServer {
pub fn infill<A: InfillRequestProvider>(
&self,
request: A,
) -> Result<CompletionResponse, ClientError> {
request.with_request(|req| self.client.post("/infill", req))
}
}
/// Body of a fill-in-the-middle request sent to `/infill` by [`LmcppServer::infill`].
///
/// Build it with [`InfillRequest::builder`]. The builder's `String` setters accept
/// anything `Into<String>` (for example `&str`). The builder itself implements
/// `Debug` and `Clone`.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(on(String, into))]
#[builder(derive(Debug, Clone))]
pub struct InfillRequest {
    /// Text serialized as `input_prefix`.
    ///
    /// Presumably the content before the span to be filled; confirm against the
    /// server's `/infill` documentation.
    pub input_prefix: String,
    /// Text serialized as `input_suffix`.
    ///
    /// Presumably the content after the span to be filled; confirm against the
    /// server's `/infill` documentation.
    pub input_suffix: String,
    /// Optional extra context entries.
    ///
    /// The `input_extra` key is omitted entirely from the serialized body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_extra: Option<Vec<InputExtra>>,
    /// Optional general completion settings.
    ///
    /// This field is flattened: its fields are written at the top level of this request
    /// rather than under a `completion` key, and nothing is written when `None`.
    #[serde(flatten)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion: Option<CompletionRequest>,
}
/// One entry of [`InfillRequest::input_extra`].
///
/// NOTE(review): presumably an additional file (name plus contents) offered to the
/// server as extra context; confirm against the server's `/infill` documentation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputExtra {
    /// Entry name, serialized as `filename`.
    pub filename: String,
    /// Entry contents, serialized as `text`.
    pub text: String,
}
/// A value that can supply an [`InfillRequest`] for the duration of one closure call.
///
/// It is implemented for:
/// - owned, `&`, and `&mut` [`InfillRequest`]s;
/// - owned, `&`, and `&mut` complete [`InfillRequestBuilder`]s.
///
/// This lets [`LmcppServer::infill`] accept any of them. Borrowed requests are lent to
/// the closure as-is, without cloning.
pub trait InfillRequestProvider {
    /// Resolves `self` to an [`InfillRequest`] and passes it by reference to `f`.
    ///
    /// Returns whatever `f` returns.
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&InfillRequest) -> R;
}
/// An owned request is lent to the closure and dropped once the closure returns.
impl InfillRequestProvider for InfillRequest {
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&InfillRequest) -> R,
    {
        let request = self;
        f(&request)
    }
}
/// A shared borrow is handed straight to the closure; nothing is cloned.
// The impl-header lifetime is elided: a named `'a` used only once adds nothing
// (clippy `needless_lifetimes` / `elidable_lifetime_names`).
impl InfillRequestProvider for &InfillRequest {
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&InfillRequest) -> R,
    {
        f(self)
    }
}
/// A mutable borrow is reborrowed as shared for the closure.
///
/// Nothing is cloned, and the request itself is not modified.
// The impl-header lifetime is elided: a named `'a` used only once adds nothing
// (clippy `needless_lifetimes` / `elidable_lifetime_names`).
impl InfillRequestProvider for &mut InfillRequest {
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&InfillRequest) -> R,
    {
        f(&*self)
    }
}
/// A complete builder is consumed: the request is built and lent to the closure.
///
/// The `IsComplete` bound restricts this impl to builders whose required fields have
/// all been set.
impl<S> InfillRequestProvider for InfillRequestBuilder<S>
where
    S: infill_request_builder::IsComplete,
{
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&InfillRequest) -> R,
    {
        f(&self.build())
    }
}
/// A borrowed complete builder is cloned, and the clone is built.
///
/// The caller's builder stays usable afterwards.
// The impl-header lifetime is elided: a named `'a` used only once adds nothing
// (clippy `needless_lifetimes` / `elidable_lifetime_names`).
impl<S> InfillRequestProvider for &InfillRequestBuilder<S>
where
    S: infill_request_builder::IsComplete,
{
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&InfillRequest) -> R,
    {
        // `build` takes the builder by value, so a borrow must be cloned first.
        let req = self.clone().build();
        f(&req)
    }
}
/// A mutably borrowed complete builder is cloned, and the clone is built.
///
/// The caller's builder is left unchanged and stays usable afterwards.
// The impl-header lifetime is elided: a named `'a` used only once adds nothing
// (clippy `needless_lifetimes` / `elidable_lifetime_names`).
impl<S> InfillRequestProvider for &mut InfillRequestBuilder<S>
where
    S: infill_request_builder::IsComplete,
{
    fn with_request<F, R>(self, f: F) -> R
    where
        F: FnOnce(&InfillRequest) -> R,
    {
        // `build` takes the builder by value; clone so the borrowed builder survives.
        let req = self.clone().build();
        f(&req)
    }
}
#[cfg(test)]
mod tests {
    use serial_test::serial;

    use super::*;
    use crate::{
        LmcppServer,
        error::LmcppResult,
        server::{
            builder::LmcppServerLauncher, toolchain::builder::LmcppToolChain,
            types::start_args::ServerArgs,
        },
    };

    /// End-to-end smoke test of `/infill` against a launched server.
    ///
    /// The server is launched for `bartowski/codegemma-2b-GGUF` and the response is
    /// printed. The test is ignored by default because it launches a real server via
    /// `LmcppServerLauncher::load`. It is marked `#[serial]` so it never overlaps
    /// other serial tests.
    #[test]
    #[ignore]
    #[serial]
    fn test_lmcpp_server_infill() -> LmcppResult<()> {
        let toolchain = LmcppToolChain::builder().install_only().build()?;
        let server_args = ServerArgs::builder()
            .hf_repo("bartowski/codegemma-2b-GGUF")?
            .build();
        let client = LmcppServerLauncher::builder()
            .toolchain(toolchain)
            .server_args(server_args)
            .load()?;

        let request = InfillRequest::builder()
            .input_prefix("Hello, ")
            .input_suffix(" world!")
            .completion(CompletionRequest::builder().prompt("Hello, world!").build())
            .build();
        let response = client.infill(request)?;
        println!("Infill response: {:#?}", response);
        Ok(())
    }

    /// Exercises every request form `infill` accepts.
    ///
    /// The forms are a built `InfillRequest` and a complete builder. Each is passed by
    /// value (from both plain and `mut` bindings), by `&`, and by `&mut`. Results are
    /// discarded; the point is that each call type-checks and runs.
    #[test]
    #[ignore]
    // The by-value cases intentionally use `let mut` bindings that are never
    // mutated, to mirror the `&mut` cases; silence the resulting lint.
    #[allow(unused_mut)]
    fn test_lmcpp_server_infill_variants() -> LmcppResult<()> {
        let client = LmcppServer::dummy();
        // Produces a fresh, complete builder for each case below.
        let complete_builder = || {
            InfillRequest::builder()
                .input_prefix("hi")
                .input_suffix("ho")
                .completion(CompletionRequest::builder().prompt("Test request").build())
        };

        // Built `InfillRequest`: by value, by value from `mut`, by `&`, by `&mut`.
        let request = complete_builder().build();
        let _ = client.infill(request);
        let mut request = complete_builder().build();
        let _ = client.infill(request);
        let request = complete_builder().build();
        let _ = client.infill(&request);
        let mut request = complete_builder().build();
        let _ = client.infill(&mut request);

        // Complete `InfillRequestBuilder`: same four forms.
        let builder = complete_builder();
        let _ = client.infill(builder);
        let mut builder = complete_builder();
        let _ = client.infill(builder);
        let builder = complete_builder();
        let _ = client.infill(&builder);
        let mut builder = complete_builder();
        let _ = client.infill(&mut builder);

        Ok(())
    }
}