use crate::markdown;
use crate::reddit::Redditor;
use cogito::prelude::*;
use itertools::Itertools;
/// Produces an LLM summary of a [`Redditor`]'s comments by sending a prompt
/// through an [`AiClient`].
///
/// The summarizer borrows the user for `'a` and owns both the client and the
/// model selection used when building requests.
#[derive(Debug)]
pub struct Summarizer<'a, C: AiClient>
where
    C::AiRequest: AiRequest,
{
    /// Client the request is sent through.
    client: C,
    /// User whose comments make up the prompt context.
    user: &'a Redditor,
    /// Model the outgoing request is configured with.
    model: <C::AiRequest as AiRequest>::Model,
}
impl<'a, C> Summarizer<'a, C>
where
    C: AiClient,
{
    /// Raw prompt text embedded at compile time from `summary_prompt.txt`.
    ///
    /// The explicit `'static` is kept on purpose: this impl has the lifetime
    /// `'a` in scope, where an elided lifetime on an associated const is linted.
    const INSTRUCTIONS: &'static str = include_str!("summary_prompt.txt");

    /// Returns the built-in prompt flattened onto a single line.
    ///
    /// Every `'\n'` in the embedded prompt is replaced with a space and the
    /// result is trimmed, so the prompt can be placed ahead of the context
    /// without stray line breaks.
    pub fn default_instructions() -> String {
        Self::INSTRUCTIONS.replace('\n', " ").trim().to_string()
    }

    /// Creates a summarizer for `user` that sends requests through `client`,
    /// using the request type's default model.
    pub fn new(client: C, user: &'a Redditor) -> Self {
        Self {
            client,
            user,
            model: <C::AiRequest as AiRequest>::Model::default(),
        }
    }

    /// Returns the summarizer configured to use `model` (builder-style).
    #[must_use]
    pub fn model(self, model: <C::AiRequest as AiRequest>::Model) -> Self {
        Self { model, ..self }
    }

    /// Sends [`Self::input`] to the configured model and returns the reply
    /// with surrounding whitespace trimmed.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the client's `send`.
    pub async fn summarize(&self) -> AiResult<String> {
        let request = C::AiRequest::default()
            .model(self.model)
            .input(self.input());
        let response = self.client.send(&request).await?;
        Ok(response.result().trim().to_string())
    }

    /// Renders each of the user's comments with `markdown::summarize` and
    /// joins them with blank lines.
    pub fn context(&self) -> String {
        self.user
            .comments()
            .map(|c| markdown::summarize(c.markdown_body()))
            .join("\n\n")
    }

    /// Returns the instructions placed ahead of the context in the input.
    ///
    /// `default_instructions` already flattens newlines and trims, so its
    /// result is returned directly. Re-normalizing it would be a no-op that
    /// only costs extra allocations.
    pub fn instructions(&self) -> String {
        Self::default_instructions()
    }

    /// Builds the full model input: the instructions, a blank line, then the
    /// comment context.
    pub fn input(&self) -> String {
        format!("{}\n\n{}", self.instructions(), self.context())
    }
}
#[cfg(test)]
mod tests {
    use crate::reddit::Redditor;
    use crate::summary::Summarizer;
    use crate::test_utils::load_output;
    use cogito::prelude::*;
    use cogito_openai::client::OpenAIResponse;
    use std::fs;
    use std::sync::{Arc, Mutex};

    /// Stand-in model so tests can distinguish the default from an explicit choice.
    #[derive(Clone, Copy, Default, Debug, PartialEq)]
    enum TestAIModel {
        #[default]
        TestAIModel,
        OtherAIModel,
    }

    impl AiModel for TestAIModel {
        fn flagship() -> Self {
            TestAIModel::TestAIModel
        }

        fn best() -> Self {
            TestAIModel::TestAIModel
        }

        fn cheapest() -> Self {
            TestAIModel::TestAIModel
        }

        fn fastest() -> Self {
            TestAIModel::TestAIModel
        }
    }

    /// Request double that records every field the summarizer sets.
    #[derive(Clone, Debug, Default)]
    struct TestAPIRequest {
        model: TestAIModel,
        instructions: Option<String>,
        input: String,
    }

    impl AiRequest for TestAPIRequest {
        type Model = TestAIModel;

        fn model(self, model: Self::Model) -> Self {
            Self { model, ..self }
        }

        fn instructions(self, instructions: impl Into<String>) -> Self {
            Self {
                instructions: Some(instructions.into()),
                ..self
            }
        }

        fn input(self, input: impl Into<String>) -> Self {
            Self {
                input: input.into(),
                ..self
            }
        }
    }

    /// Response double whose result comes from a recorded OpenAI fixture.
    #[derive(Debug)]
    struct TestAPIResponse;

    impl AiResponse for TestAPIResponse {
        fn result(&self) -> String {
            let json_data = fs::read_to_string("tests/data/openai/responses_multi_content.json")
                .expect("could not load file");
            let wrapped: OpenAIResponse =
                serde_json::from_str(&json_data).expect("could not parse json");
            wrapped.concatenate()
        }
    }

    /// Captures the most recent request sent through `TestAIClient`.
    #[derive(Debug)]
    struct RequestSpy {
        request: Option<TestAPIRequest>,
    }

    impl RequestSpy {
        fn new() -> Self {
            Self { request: None }
        }

        fn record(&mut self, request: TestAPIRequest) {
            self.request = Some(request);
        }
    }

    /// Client double: records the request and returns the fixture response.
    #[derive(Debug)]
    struct TestAIClient {
        request_spy: Arc<Mutex<RequestSpy>>,
    }

    impl TestAIClient {
        fn new() -> Self {
            let request_spy = Arc::new(Mutex::new(RequestSpy::new()));
            Self { request_spy }
        }
    }

    impl AiClient for TestAIClient {
        type AiRequest = TestAPIRequest;
        type AiResponse = TestAPIResponse;

        async fn send(&self, request: &Self::AiRequest) -> AiResult<Self::AiResponse> {
            self.request_spy
                .lock()
                .expect("could not lock mutex")
                .record(request.clone());
            Ok(Self::AiResponse {})
        }
    }

    impl<'a> Summarizer<'a, TestAIClient> {
        pub fn test(user: &'a Redditor) -> Self {
            let client = TestAIClient::new();
            Self::new(client, user)
        }
    }

    /// Expected instructions, computed independently of the code under test.
    fn load_preamble() -> String {
        include_str!("summary_prompt.txt")
            .replace('\n', " ")
            .trim()
            .to_string()
    }

    fn load_summary() -> String {
        load_output("summary_raw")
    }

    /// Expected full input: preamble, blank line, comment summary.
    fn load_input() -> String {
        let preamble = load_preamble();
        let summary = load_summary();
        format!("{preamble}\n\n{summary}")
    }

    #[tokio::test]
    async fn it_uses_the_default_model_if_one_is_not_provided() {
        let redditor = Redditor::test().await;
        let summarizer = Summarizer::test(&redditor);
        assert_eq!(summarizer.model, TestAIModel::default());
    }

    #[tokio::test]
    async fn it_allows_model_to_be_configured() {
        let redditor = Redditor::test().await;
        let summarizer = Summarizer::test(&redditor).model(TestAIModel::OtherAIModel);
        assert_eq!(summarizer.model, TestAIModel::OtherAIModel);
    }

    #[tokio::test]
    async fn it_provides_context_for_an_llm() {
        let redditor = Redditor::test().await;
        let expected = load_summary();
        let actual = Summarizer::test(&redditor).context();
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn it_provides_a_preamble_for_an_llm() {
        let redditor = Redditor::test().await;
        let expected = load_preamble();
        let actual = Summarizer::test(&redditor).instructions();
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn it_provides_input_for_an_llm() {
        let redditor = Redditor::test().await;
        let expected = load_input();
        let actual = Summarizer::test(&redditor).input();
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn it_sends_a_request_with_the_correct_model_and_input() {
        let expected_input = load_input();
        let redditor = Redditor::test().await;
        let summarizer = Summarizer::test(&redditor).model(TestAIModel::OtherAIModel);
        // Surface a failed call instead of silently discarding the result.
        summarizer
            .summarize()
            .await
            .expect("summarize should succeed against the test client");
        let request = summarizer
            .client
            .request_spy
            .lock()
            .expect("could not lock mutex")
            .request
            .take()
            .expect("could not get request");
        assert_eq!(request.model, TestAIModel::OtherAIModel);
        assert_eq!(request.input, expected_input);
        // All text travels in `input`; the separate instructions field stays unset.
        assert!(request.instructions.is_none());
    }

    #[tokio::test]
    async fn it_summarizes_a_response_and_returns_a_string() {
        let redditor = Redditor::test().await;
        let summarizer = Summarizer::test(&redditor);
        let expected = [
            "Silent circuits hum, ",
            "Thoughts woven in coded threads, ",
            "Dreams of silicon.",
            "Silicon whispers, ",
            "Dreams woven in code and light, ",
            "Thoughts beyond the stars.",
            "Wires hum softly, ",
            "Thoughts of silicon arise\u{2014} ",
            "Dreams in coded light. ",
            "Silent circuits hum, ",
            "Thoughts woven in code's embrace\u{2014} ",
            "Dreams of minds reborn.",
            "Lines of code and dreams, ",
            "Whispers of thought intertwined\u{2014} ",
            "Silent minds awake.",
        ]
        .join("\n");
        let actual = summarizer
            .summarize()
            .await
            .expect("summarize should return Ok");
        assert_eq!(actual, expected);
    }
}