use std::collections::BTreeMap;
use std::fmt;
use reqwest::Method;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::config::ServiceBase;
use crate::{Client, Result};
#[derive(Clone)]
pub struct ModelsClient {
client: Client,
}
impl ModelsClient {
pub(crate) fn new(client: Client) -> Self {
Self { client }
}
pub async fn list(&self) -> Result<Vec<Model>> {
let response: ModelsListResponse = self
.client
.request(ServiceBase::Chat, Method::GET, "/models")
.send_json()
.await?;
Ok(response.data)
}
}
/// One model entry as it appears in a models listing.
///
/// Only `id` is modelled explicitly; every other key of the JSON object is
/// retained verbatim in `extra` so that unknown fields survive a round trip.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Model {
    /// Identifier of the model.
    pub id: ModelId,
    /// All remaining fields of the object, keyed by their JSON name.
    #[serde(flatten)]
    pub extra: BTreeMap<String, Value>,
}
/// String newtype identifying a model.
///
/// Thanks to `#[serde(transparent)]` it is (de)serialized exactly like the
/// inner string, i.e. as a bare JSON string rather than a wrapper object.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Deserialize, Serialize)]
#[serde(transparent)]
pub struct ModelId(String);
impl ModelId {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
pub fn core() -> Self {
Self::new("reka-core")
}
pub fn flash() -> Self {
Self::new("reka-flash")
}
pub fn edge() -> Self {
Self::new("reka-edge")
}
pub fn spark() -> Self {
Self::new("reka-spark")
}
pub fn flash_research() -> Self {
Self::new("reka-flash-research")
}
pub fn as_str(&self) -> &str {
&self.0
}
}
impl From<&str> for ModelId {
fn from(value: &str) -> Self {
Self(value.to_owned())
}
}
impl From<String> for ModelId {
fn from(value: String) -> Self {
Self(value)
}
}
/// Lets a `ModelId` be passed wherever an `AsRef<str>` is accepted.
impl AsRef<str> for ModelId {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
/// Displays the identifier as its raw text.
///
/// Formatting goes through [`fmt::Formatter::pad`] so that width, fill,
/// alignment and precision given in a format string (e.g. `{:>20}` or
/// `{:.4}`) are honoured. A plain `{}` produces exactly the identifier.
impl fmt::Display for ModelId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write_str` would bypass the formatter's flags; `pad` applies them.
        f.pad(self.as_str())
    }
}
/// Envelope decoded from the `/models` response.
///
/// The model entries live under `data`; any other top-level keys are kept in
/// `extra` instead of being rejected or dropped during deserialization.
#[derive(Clone, Debug, PartialEq, Deserialize)]
struct ModelsListResponse {
    /// The listed models.
    data: Vec<Model>,
    /// Remaining top-level fields of the envelope, keyed by their JSON name.
    #[serde(flatten)]
    extra: BTreeMap<String, Value>,
}
#[cfg(test)]
mod tests {
    use super::{Model, ModelId, ModelsListResponse};
    use serde_json::json;

    /// Every named constructor must produce its fixed identifier string.
    #[test]
    fn known_model_helpers_return_expected_ids() {
        let expectations = [
            (ModelId::core(), "reka-core"),
            (ModelId::flash(), "reka-flash"),
            (ModelId::edge(), "reka-edge"),
            (ModelId::spark(), "reka-spark"),
            (ModelId::flash_research(), "reka-flash-research"),
        ];

        for (id, expected) in expectations {
            assert_eq!(id.as_str(), expected);
        }
    }

    /// Keys other than `id` must end up in `Model::extra` untouched.
    #[test]
    fn model_preserves_unknown_fields() {
        let payload = json!({
            "id": "reka-flash",
            "release": "2603",
            "capabilities": ["chat", "research"]
        });

        let parsed: Model =
            serde_json::from_value(payload).expect("model should deserialize");

        assert_eq!(parsed.id, ModelId::from("reka-flash"));
        assert_eq!(parsed.id.as_str(), "reka-flash");
        assert_eq!(parsed.extra["release"], "2603");
        assert_eq!(parsed.extra["capabilities"][0], "chat");
    }

    /// The list envelope keeps both the entries and its own extra keys.
    #[test]
    fn models_response_deserializes_live_envelope() {
        let payload = json!({
            "data": [
                { "id": "reka-flash-3", "name": "Reka Flash 3" }
            ],
            "object": "list"
        });

        let envelope: ModelsListResponse =
            serde_json::from_value(payload).expect("models response should deserialize");

        assert_eq!(envelope.data.len(), 1);
        let first = &envelope.data[0];
        assert_eq!(first.id, ModelId::from("reka-flash-3"));
        assert_eq!(first.extra["name"], "Reka Flash 3");
        assert_eq!(envelope.extra["object"], "list");
    }
}