use crate::chat::Format;
use crate::client::ModelClient;
use crate::client::handle_error_response;
use crate::client::json_lines_stream;
use crate::error::{OllamaError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio_stream::Stream;
/// Request body for the `api/generate` endpoint.
///
/// [`ModelClient::generate`] and [`ModelClient::generate_stream`] serialize this struct as the
/// JSON body of the request. Every `Option` field is skipped during serialization when it is
/// `None`, so only explicitly set parameters are sent. `stream`, `model`, and `prompt` are
/// always serialized.
///
/// Build with struct-update syntax, e.g.
/// `GenerateRequest { model: "m".into(), prompt: "p".into(), ..Default::default() }`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GenerateRequest {
/// Name of the model to run. Also passed to the client's response/error handlers.
pub model: String,
/// Prompt text sent to the model.
pub prompt: String,
/// The `stream` flag of the request body. It is always serialized. It is `false` when the
/// request is built via `Default` or when the key is absent during deserialization.
#[serde(default)]
pub stream: bool,
/// Sent as `suffix` when set. Presumably text that should follow the generated output;
/// confirm against the server API.
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
/// Sent as `images` when set. Presumably base64-encoded images for multimodal models;
/// confirm against the server API.
#[serde(skip_serializing_if = "Option::is_none")]
pub images: Option<Vec<String>>,
/// Requested output format, sent as `format` when set.
#[serde(skip_serializing_if = "Option::is_none")]
pub format: Option<Format>,
/// Free-form model parameters, forwarded verbatim as the JSON object `options`.
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<HashMap<String, serde_json::Value>>,
/// Sent as `system` when set. Presumably overrides the model's system prompt; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<String>,
/// Sent as `template` when set. Presumably overrides the model's prompt template; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub template: Option<String>,
/// Sent as `raw` when set. Presumably disables server-side prompt templating; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub raw: Option<bool>,
/// Sent as the string `keep_alive` when set. Presumably controls how long the model stays
/// loaded after the request; confirm the accepted format with the server API.
#[serde(skip_serializing_if = "Option::is_none")]
pub keep_alive: Option<String>,
/// Sent as `context` when set. It has the same type as [`GenerateResponse::context`], so it is
/// presumably meant to carry that value forward to continue a previous exchange; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub context: Option<Vec<u32>>,
/// Sent as `think` when set. Presumably enables "thinking" output on models that support it.
#[serde(skip_serializing_if = "Option::is_none")]
pub think: Option<bool>,
/// Sent as `width` when set. Presumably an image-generation parameter; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub width: Option<u32>,
/// Sent as `height` when set. Presumably an image-generation parameter; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub height: Option<u32>,
/// Sent as `steps` when set. Presumably an image-generation parameter; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub steps: Option<u32>,
}
/// A response object returned by the `api/generate` endpoint.
///
/// [`ModelClient::generate`] returns one of these. [`ModelClient::generate_stream`] returns a
/// stream of them. The count and duration fields default to `0` when absent from the JSON, so
/// objects that omit them (presumably intermediate stream chunks) still deserialize.
///
/// NOTE(review): the struct has no `deny_unknown_fields`, so serde silently ignores any JSON key
/// without a matching field here. If the server sends extra output (for example a `thinking`
/// field when [`GenerateRequest::think`] is set), that output is dropped. Confirm whether such
/// fields should be added.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateResponse {
/// Name of the model that produced this object.
pub model: String,
/// Creation timestamp as sent by the server (kept as an unparsed string).
pub created_at: String,
/// Generated text carried by this object (presumably a fragment when streaming).
pub response: String,
/// Presumably `true` on the final object of a generation; confirm.
pub done: bool,
/// Presumably why generation stopped; only expected alongside `done`. Skipped when
/// re-serializing if `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub done_reason: Option<String>,
/// Token context; can be passed back via [`GenerateRequest::context`]. Skipped when
/// re-serializing if `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub context: Option<Vec<u32>>,
/// Total time spent; `0` if absent. Presumably in nanoseconds; confirm.
#[serde(default)]
pub total_duration: u64,
/// Model load time; `0` if absent. Presumably in nanoseconds; confirm.
#[serde(default)]
pub load_duration: u64,
/// Number of prompt tokens evaluated; `0` if absent.
#[serde(default)]
pub prompt_eval_count: u32,
/// Time spent evaluating the prompt; `0` if absent. Presumably in nanoseconds; confirm.
#[serde(default)]
pub prompt_eval_duration: u64,
/// Number of tokens generated; `0` if absent.
#[serde(default)]
pub eval_count: u32,
/// Time spent generating; `0` if absent. Presumably in nanoseconds; confirm.
#[serde(default)]
pub eval_duration: u64,
}
impl ModelClient {
    /// Sends a non-streaming request to `api/generate` and returns the single response object.
    ///
    /// The request's `stream` flag is forced to `false` before sending. This method yields
    /// exactly one [`GenerateResponse`], so asking the server for a streamed (multi-object) body
    /// would not match how the reply is decoded. Any caller-supplied value of
    /// `request.stream` is therefore ignored.
    ///
    /// # Errors
    ///
    /// - [`OllamaError::UrlError`] if the endpoint URL cannot be built from `base_url`.
    /// - [`OllamaError::RequestError`] if the HTTP request cannot be sent.
    /// - Whatever `handle_response` reports for the reply, such as an error status or a body
    ///   that cannot be decoded.
    pub async fn generate(&self, mut request: GenerateRequest) -> Result<GenerateResponse> {
        // One request, one response object: never let the server stream here.
        request.stream = false;

        let url = self
            .base_url
            .join("api/generate")
            .map_err(OllamaError::UrlError)?;
        let response = self
            .client
            .post(url)
            .json(&request)
            .send()
            .await
            .map_err(OllamaError::RequestError)?;
        self.handle_response(response, Some(&request.model)).await
    }

    /// Sends a streaming request to `api/generate` and returns a stream of response objects.
    ///
    /// The request's `stream` flag is forced to `true` before sending. Without this, a request
    /// built via `Default` (where `stream` is `false`) would ask the server for a single
    /// non-streamed object, defeating the purpose of this method. Any caller-supplied value of
    /// `request.stream` is therefore ignored.
    ///
    /// A non-success HTTP status is turned into an error via `handle_error_response` before any
    /// stream is returned. Otherwise the response body is handed to `json_lines_stream`, which
    /// presumably decodes one [`GenerateResponse`] per line of the body.
    ///
    /// # Errors
    ///
    /// - [`OllamaError::UrlError`] if the endpoint URL cannot be built from `base_url`.
    /// - [`OllamaError::RequestError`] if the HTTP request cannot be sent.
    /// - The error produced by `handle_error_response` for a non-success status.
    ///
    /// Errors that occur while reading or decoding the body are reported as `Err` items of the
    /// returned stream.
    pub async fn generate_stream(
        &self,
        mut request: GenerateRequest,
    ) -> Result<impl Stream<Item = Result<GenerateResponse>> + '_> {
        // This method exists to stream; make sure the server is asked to.
        request.stream = true;

        let url = self
            .base_url
            .join("api/generate")
            .map_err(OllamaError::UrlError)?;
        let response = self
            .client
            .post(url)
            .json(&request)
            .send()
            .await
            .map_err(OllamaError::RequestError)?;
        if !response.status().is_success() {
            return Err(handle_error_response(response, Some(&request.model)).await);
        }
        Ok(json_lines_stream(response))
    }
}