use super::{
Error, JsonSnafu, ModelListParams, Schema, SchemaAllowCreate, SchemaAllowEdit, SchemaType,
SchemaView, SqlxSnafu, Status, format_datetime,
};
use serde::{Deserialize, Serialize};
use snafu::ResultExt;
use sqlx::FromRow;
use sqlx::{Pool, Postgres, QueryBuilder};
use std::collections::HashMap;
use tibba_model::Model;
use time::PrimitiveDateTime;
/// Result alias used throughout this module; errors are the shared model [`Error`] type.
type Result<T> = std::result::Result<T, Error>;
/// Raw row of the `token_prices` table as decoded by `sqlx` from `SELECT *` queries.
///
/// Converted into the API-facing [`TokenPrice`] via its `From` impl.
#[derive(FromRow)]
struct TokenPriceSchema {
    id: i64,
    /// Service this price belongs to.
    service: String,
    /// Model name; an empty string is used as the service-wide fallback row.
    model: String,
    /// Price charged per `unit_size` input tokens.
    input_price: i64,
    /// Price charged per `unit_size` output tokens.
    output_price: i64,
    /// Flat amount added to every cost calculation.
    fixed_price: i64,
    /// Number of tokens the input/output prices apply to.
    unit_size: i32,
    /// Row status; price lookups only consider rows with `status = 1`.
    status: i16,
    remark: String,
    created: PrimitiveDateTime,
    modified: PrimitiveDateTime,
}
/// A token price entry as exposed to callers and serialized over the API.
///
/// Mirrors the `token_prices` row, except that `created`/`modified` are
/// pre-formatted strings (produced by `format_datetime`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TokenPrice {
    pub id: i64,
    /// Service this price belongs to.
    pub service: String,
    /// Model name; empty string denotes the service-wide fallback price.
    pub model: String,
    /// Price per `unit_size` input tokens.
    pub input_price: i64,
    /// Price per `unit_size` output tokens.
    pub output_price: i64,
    /// Flat amount added on top of the token-based cost.
    pub fixed_price: i64,
    /// Token count that `input_price` / `output_price` are quoted for.
    pub unit_size: i32,
    pub status: i16,
    pub remark: String,
    /// Creation time, formatted as a string.
    pub created: String,
    /// Last modification time, formatted as a string.
    pub modified: String,
}
impl From<TokenPriceSchema> for TokenPrice {
fn from(s: TokenPriceSchema) -> Self {
Self {
id: s.id,
service: s.service,
model: s.model,
input_price: s.input_price,
output_price: s.output_price,
fixed_price: s.fixed_price,
unit_size: s.unit_size,
status: s.status,
remark: s.remark,
created: format_datetime(s.created),
modified: format_datetime(s.modified),
}
}
}
/// JSON payload accepted by `TokenPriceModel::insert`.
///
/// Optional fields fall back to the defaults applied in `insert`:
/// `model` -> `""`, `fixed_price` -> `0`, `unit_size` -> `1000`,
/// `status` -> `Status::Enabled`, `remark` -> `""`.
#[derive(Debug, Clone, Deserialize)]
pub struct TokenPriceInsertParams {
    pub service: String,
    /// Omitted/`None` stores an empty model, i.e. the service-wide fallback price.
    pub model: Option<String>,
    pub input_price: i64,
    pub output_price: i64,
    pub fixed_price: Option<i64>,
    pub unit_size: Option<i32>,
    pub status: Option<i16>,
    pub remark: Option<String>,
}
/// JSON payload accepted by `TokenPriceModel::update_by_id`.
///
/// Each `Some` field is written to the row; `None` fields are left unchanged.
/// `service` and `model` are not updatable through this payload.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct TokenPriceUpdateParams {
    pub input_price: Option<i64>,
    pub output_price: Option<i64>,
    pub fixed_price: Option<i64>,
    pub unit_size: Option<i32>,
    pub status: Option<i16>,
    pub remark: Option<String>,
}
/// Stateless accessor for the `token_prices` table; every method takes the
/// connection pool explicitly.
#[derive(Default)]
pub struct TokenPriceModel {}
impl TokenPriceModel {
    /// Looks up the active price for `service` / `model`.
    ///
    /// Tries an exact `(service, model)` match first. If none exists and
    /// `model` is non-empty, falls back to the service-wide row whose model
    /// is the empty string. Only rows with `status = 1` that are not
    /// soft-deleted are considered.
    ///
    /// NOTE(review): `LIMIT 1` without `ORDER BY` picks an arbitrary row if
    /// several active rows share the same key — confirm a uniqueness
    /// constraint exists on `(service, model)`.
    ///
    /// # Errors
    /// Returns a `Sqlx` error if either query fails.
    pub async fn get_by_service_model(
        &self,
        pool: &Pool<Postgres>,
        service: &str,
        model: &str,
    ) -> Result<Option<TokenPrice>> {
        let exact = sqlx::query_as::<_, TokenPriceSchema>(
            r#"SELECT * FROM token_prices
            WHERE service = $1 AND model = $2 AND status = 1 AND deleted_at IS NULL
            LIMIT 1"#,
        )
        .bind(service)
        .bind(model)
        .fetch_optional(pool)
        .await
        .context(SqlxSnafu)?;
        if let Some(row) = exact {
            return Ok(Some(row.into()));
        }
        // An empty model already queried the service-wide default row above.
        if model.is_empty() {
            return Ok(None);
        }
        let fallback = sqlx::query_as::<_, TokenPriceSchema>(
            r#"SELECT * FROM token_prices
            WHERE service = $1 AND model = '' AND status = 1 AND deleted_at IS NULL
            LIMIT 1"#,
        )
        .bind(service)
        .fetch_optional(pool)
        .await
        .context(SqlxSnafu)?;
        Ok(fallback.map(Into::into))
    }

    /// Computes the total cost of a request under `price`.
    ///
    /// `fixed_price + ceil(input_tokens * input_price / unit)
    ///              + ceil(output_tokens * output_price / unit)`,
    /// where `unit` is `price.unit_size` clamped to at least 1 (so a zero or
    /// negative unit size never divides by zero).
    ///
    /// Negative token counts are treated as 0; the ceiling-division trick is
    /// only valid for non-negative numerators. The arithmetic runs in `i128`,
    /// so large prices cannot overflow (which would panic in debug or wrap in
    /// release). The result saturates to the `i64` range.
    pub fn calculate_cost(price: &TokenPrice, input_tokens: i32, output_tokens: i32) -> i64 {
        let unit = i128::from(price.unit_size.max(1));
        let ceil_cost = |tokens: i32, per_unit: i64| -> i128 {
            let tokens = i128::from(tokens.max(0));
            // Round up so any partial unit is billed as a whole unit.
            (tokens * i128::from(per_unit) + unit - 1) / unit
        };
        let total = i128::from(price.fixed_price)
            + ceil_cost(input_tokens, price.input_price)
            + ceil_cost(output_tokens, price.output_price);
        total.clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64
    }
}
impl Model for TokenPriceModel {
    type Output = TokenPrice;

    fn new() -> Self {
        Self::default()
    }

    /// Describes the editable/filterable columns for the admin UI.
    ///
    /// `service` and `model` are fixed after creation and filterable; create
    /// and edit are restricted to the `su` role.
    async fn schema_view(&self, _pool: &Pool<Postgres>) -> SchemaView {
        SchemaView {
            schemas: vec![
                Schema::new_id(),
                Schema {
                    name: "service".to_string(),
                    category: SchemaType::String,
                    required: true,
                    fixed: true,
                    filterable: true,
                    ..Default::default()
                },
                Schema {
                    name: "model".to_string(),
                    category: SchemaType::String,
                    fixed: true,
                    filterable: true,
                    ..Default::default()
                },
                Schema {
                    name: "input_price".to_string(),
                    category: SchemaType::Number,
                    required: true,
                    ..Default::default()
                },
                Schema {
                    name: "output_price".to_string(),
                    category: SchemaType::Number,
                    required: true,
                    ..Default::default()
                },
                Schema {
                    name: "fixed_price".to_string(),
                    category: SchemaType::Number,
                    ..Default::default()
                },
                Schema {
                    name: "unit_size".to_string(),
                    category: SchemaType::Number,
                    default_value: Some(serde_json::json!(1000)),
                    ..Default::default()
                },
                Schema::new_status(),
                Schema::new_remark(),
                Schema::new_created(),
                Schema::new_modified(),
            ],
            allow_edit: SchemaAllowEdit {
                roles: vec!["su".to_string()],
                ..Default::default()
            },
            allow_create: SchemaAllowCreate {
                roles: vec!["su".to_string()],
                ..Default::default()
            },
        }
    }

    /// Inserts a new price row from a `TokenPriceInsertParams` JSON payload
    /// and returns the generated id.
    ///
    /// # Errors
    /// `Json` if the payload does not deserialize; `Sqlx` if the insert fails.
    async fn insert(&self, pool: &Pool<Postgres>, data: serde_json::Value) -> Result<u64> {
        let p: TokenPriceInsertParams = serde_json::from_value(data).context(JsonSnafu)?;
        let row: (i64,) = sqlx::query_as(
            r#"INSERT INTO token_prices
            (service, model, input_price, output_price, fixed_price, unit_size, status, remark)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
            RETURNING id"#,
        )
        .bind(&p.service)
        .bind(p.model.unwrap_or_default())
        .bind(p.input_price)
        .bind(p.output_price)
        .bind(p.fixed_price.unwrap_or(0))
        .bind(p.unit_size.unwrap_or(1000))
        .bind(p.status.unwrap_or(Status::Enabled as i16))
        .bind(p.remark.unwrap_or_default())
        .fetch_one(pool)
        .await
        .context(SqlxSnafu)?;
        Ok(row.0 as u64)
    }

    /// Fetches a non-deleted row by id, or `None` if it does not exist.
    async fn get_by_id(&self, pool: &Pool<Postgres>, id: u64) -> Result<Option<Self::Output>> {
        let result = sqlx::query_as::<_, TokenPriceSchema>(
            r#"SELECT * FROM token_prices WHERE id = $1 AND deleted_at IS NULL"#,
        )
        .bind(id as i64)
        .fetch_optional(pool)
        .await
        .context(SqlxSnafu)?;
        Ok(result.map(Into::into))
    }

    /// Applies a partial update from a `TokenPriceUpdateParams` JSON payload.
    ///
    /// `modified` is always bumped; only `Some` fields are written. Updating
    /// a missing or soft-deleted id affects no rows and still returns `Ok`.
    async fn update_by_id(
        &self,
        pool: &Pool<Postgres>,
        id: u64,
        data: serde_json::Value,
    ) -> Result<()> {
        let p: TokenPriceUpdateParams = serde_json::from_value(data).context(JsonSnafu)?;
        let mut qb: QueryBuilder<Postgres> =
            QueryBuilder::new("UPDATE token_prices SET modified = NOW()");
        if let Some(v) = p.input_price {
            qb.push(", input_price = ").push_bind(v);
        }
        if let Some(v) = p.output_price {
            qb.push(", output_price = ").push_bind(v);
        }
        if let Some(v) = p.fixed_price {
            qb.push(", fixed_price = ").push_bind(v);
        }
        if let Some(v) = p.unit_size {
            qb.push(", unit_size = ").push_bind(v);
        }
        if let Some(v) = p.status {
            qb.push(", status = ").push_bind(v);
        }
        if let Some(v) = p.remark {
            qb.push(", remark = ").push_bind(v);
        }
        qb.push(" WHERE id = ")
            .push_bind(id as i64)
            .push(" AND deleted_at IS NULL");
        qb.build().execute(pool).await.context(SqlxSnafu)?;
        Ok(())
    }

    /// Soft-deletes a row by setting `deleted_at`; already-deleted rows are untouched.
    async fn delete_by_id(&self, pool: &Pool<Postgres>, id: u64) -> Result<()> {
        sqlx::query(
            r#"UPDATE token_prices SET deleted_at = NOW() WHERE id = $1 AND deleted_at IS NULL"#,
        )
        .bind(id as i64)
        .execute(pool)
        .await
        .context(SqlxSnafu)?;
        Ok(())
    }

    /// Counts rows matching the conditions appended by the trait's `push_conditions`.
    async fn count(&self, pool: &Pool<Postgres>, params: &ModelListParams) -> Result<i64> {
        let mut qb: QueryBuilder<Postgres> = QueryBuilder::new("SELECT COUNT(*) FROM token_prices");
        self.push_conditions(&mut qb, params)?;
        let row: (i64,) = qb
            .build_query_as()
            .fetch_one(pool)
            .await
            .context(SqlxSnafu)?;
        Ok(row.0)
    }

    /// Lists one page of rows matching the conditions from `push_conditions`.
    async fn list(
        &self,
        pool: &Pool<Postgres>,
        params: &ModelListParams,
    ) -> Result<Vec<Self::Output>> {
        let mut qb: QueryBuilder<Postgres> = QueryBuilder::new("SELECT * FROM token_prices");
        self.push_conditions(&mut qb, params)?;
        params.push_pagination(&mut qb);
        let rows = qb
            .build_query_as::<TokenPriceSchema>()
            .fetch_all(pool)
            .await
            .context(SqlxSnafu)?;
        Ok(rows.into_iter().map(Into::into).collect())
    }

    /// Appends `AND ...` clauses for the supported filters.
    ///
    /// Supports exact matches on `service` and `model` (both marked
    /// `filterable` in `schema_view`) and on `status`. A `status` value that
    /// does not parse as `i16` is ignored.
    fn push_filter_conditions<'args>(
        &self,
        qb: &mut QueryBuilder<'args, Postgres>,
        filters: &HashMap<String, String>,
    ) -> Result<()> {
        if let Some(service) = filters.get("service") {
            qb.push(" AND service = ").push_bind(service.clone());
        }
        // `model` is advertised as filterable in the schema; without this the
        // UI filter would be silently ignored.
        if let Some(model) = filters.get("model") {
            qb.push(" AND model = ").push_bind(model.clone());
        }
        if let Some(status) = filters.get("status").and_then(|s| s.parse::<i16>().ok()) {
            qb.push(" AND status = ").push_bind(status);
        }
        Ok(())
    }
}