sollama 0.1.1

A CLI tool that searches and summarizes the results with Ollama models in your terminal.
Documentation
use crate::{config::LLMConfig, Result, ScraperError};
use indicatif::{ProgressBar, ProgressStyle};
use reqwest::Client;
use serde_json::json;
use std::time::Duration;
use tracing::{debug, instrument};

/// Client for sending prompts to a Language Model (LLM) endpoint over HTTP.
///
/// An `LLMProcessor` owns both the HTTP client and the LLM configuration; the
/// processing methods on this type POST a JSON request to the configured
/// endpoint and extract the generated text from the JSON reply.
pub struct LLMProcessor {
    /// The `reqwest` HTTP client used to send requests to the LLM endpoint.
    client: Client,
    /// The LLM configuration; the endpoint, temperature, and max-token settings
    /// are read from it when a request is built and sent.
    config: LLMConfig,
}

/// Details of a single response generated by the LLM.
///
/// This is a plain data carrier: it can be cloned, compared for equality, and
/// printed with `{:?}`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProcessedResponse {
    /// The text generated by the LLM.
    pub content: String,
    /// Approximate number of tokens in `content`.
    ///
    /// This is an estimate obtained by counting whitespace-separated words; it
    /// is not the tokenizer's real token count.
    pub token_count: usize,
    /// Wall-clock time spent on the request, from building it until the
    /// response text was extracted.
    pub processing_time: Duration,
    /// The name of the model that was asked to generate the response.
    pub model: String,
}

impl LLMProcessor {
    /// Creates a new `LLMProcessor` with the given configuration.
    ///
    /// # Arguments
    ///
    /// * `config` - The configuration for the LLM.
    ///
    /// # Returns
    ///
    /// A new instance of `LLMProcessor`.
    pub fn new(config: LLMConfig) -> Self {
        Self {
            client: Client::new(),
            config,
        }
    }

    /// Creates a progress bar with a spinner style and a custom message.
    ///
    /// # Arguments
    ///
    /// * `msg` - The message to display on the progress bar.
    ///
    /// # Returns
    ///
    /// A `ProgressBar` instance with the specified message.
    fn create_progress_bar(&self, msg: &str) -> ProgressBar {
        let spinner = ProgressBar::new_spinner();
        spinner.set_style(
            ProgressStyle::default_spinner()
                .template("{spinner:.green} {msg}")
                .unwrap(),
        );
        spinner.enable_steady_tick(Duration::from_millis(120));
        spinner.set_message(msg.to_string());
        spinner
    }

    /// Processes a prompt using the LLM and returns the generated response as a string.
    ///
    /// # Arguments
    ///
    /// * `prompt` - The prompt to be processed by the LLM.
    /// * `model` - The model to be used for processing the prompt.
    ///
    /// # Returns
    ///
    /// A `Result` containing the generated response as a string, or an error if the processing fails.
    #[instrument(skip(self, prompt), fields(prompt_length = prompt.len()))]
    pub async fn process(&self, prompt: &str, model: &str) -> Result<String> {
        let response = self.process_with_details(prompt, model).await?;
        Ok(response.content)
    }

    /// Processes a prompt using the LLM and returns detailed information about the response.
    ///
    /// # Arguments
    ///
    /// * `prompt` - The prompt to be processed by the LLM.
    /// * `model` - The model to be used for processing the prompt.
    ///
    /// # Returns
    ///
    /// A `Result` containing a `ProcessedResponse` with detailed information about the response, or an error if the processing fails.
    pub async fn process_with_details(&self, prompt: &str, model: &str) -> Result<ProcessedResponse> {
        debug!("Processing LLM request with prompt: {}", prompt);
        let spinner = self.create_progress_bar("Preparing LLM request...");
        let start_time = std::time::Instant::now();

        let request = json!({
            "system" : String::from(
                "You are a helpful assistant that analyzes text content to answer questions. \
                you will receive a lot of content and a statement or a query, Your responses should be \
                about the question or query or statement that was given as a prompt and nothing more :\n\
                1. Make your reply Accurate and based on the provided content\n\
                2. Well-structured and easy to understand\n\
                3. Directly addressing the original question or prompt\n\
                4. Including relevant citations when appropriate"
            ),
            "model": model,
            "prompt": prompt,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "stream": false
        });

        // info!("Sending request to LLM model: {}", request.to_string());

        // Request phase
        spinner.set_message(format!("Sending request to {}...", model));
        let response = match self.client
            .post(&self.config.endpoint)
            .json(&request)
            .send()
            .await
        {
            Ok(resp) => resp,
            Err(e) => {
                spinner.finish_with_message("❌ LLM request failed!");
                return Err(ScraperError::LLMError(e.to_string()));
            }
        };

        // Processing phase
        spinner.set_message("Processing response...");
        let result: serde_json::Value = match response.json().await {
            Ok(json) => json,
            Err(e) => {
                spinner.finish_with_message("❌ Failed to parse LLM response!");
                return Err(ScraperError::LLMError(e.to_string()));
            }
        };

        let response_text = result["response"]
            .as_str()
            .map(String::from)
            .ok_or_else(|| {
                spinner.finish_with_message("❌ Invalid LLM response format!");
                ScraperError::LLMError("Invalid LLM response format".to_string())
            })?;

        let processing_time = start_time.elapsed();
        let token_estimate = response_text.split_whitespace().count();

        spinner.finish_with_message(format!(
            "✨ Generated response (~{} tokens) in {:.2?}",
            token_estimate,
            processing_time
        ));

        Ok(ProcessedResponse {
            content: response_text,
            token_count: token_estimate,
            processing_time,
            model: model.to_string(),
        })
    }
}