#![allow(dead_code)]
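//! Shared types for multimodal models: modality and capability descriptors,
//! message content parts (text, image, video, audio, sensor, action), and
//! static registries of known vision-language and vision-language-action
//! models.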
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::models::MessageRole;
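/// A single input or output channel a model can consume or produce.
/// Serialized in snake_case (e.g. "point_cloud"), matching the Display impl
/// below; keep the two in sync when adding variants.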
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Modality {
Text,
Image,
Video,
Audio,
PointCloud,
Action,
Sensor,
Depth,
Segmentation,
BoundingBox,
Pose,
Trajectory,
}
impl std::fmt::Display for Modality {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Modality::Text => write!(f, "text"),
Modality::Image => write!(f, "image"),
Modality::Video => write!(f, "video"),
Modality::Audio => write!(f, "audio"),
Modality::PointCloud => write!(f, "point_cloud"),
Modality::Action => write!(f, "action"),
Modality::Sensor => write!(f, "sensor"),
Modality::Depth => write!(f, "depth"),
Modality::Segmentation => write!(f, "segmentation"),
Modality::BoundingBox => write!(f, "bounding_box"),
Modality::Pose => write!(f, "pose"),
Modality::Trajectory => write!(f, "trajectory"),
}
}
}
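/// Coarse model family, named for the modalities it spans: LLM (text-only),
/// VLM (vision-language), VLA (vision-language-action), ALM (audio-language),
/// and VALM (vision-audio-language), plus general multimodal and embodied
/// models. Serialized UPPERCASE (e.g. "MULTIMODAL"), while Display keeps the
/// mixed-case names.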
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum ModelCategory {
LLM,
VLM,
VLA,
ALM,
VALM,
Multimodal,
Embodied,
}
impl std::fmt::Display for ModelCategory {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ModelCategory::LLM => write!(f, "LLM"),
ModelCategory::VLM => write!(f, "VLM"),
ModelCategory::VLA => write!(f, "VLA"),
ModelCategory::ALM => write!(f, "ALM"),
ModelCategory::VALM => write!(f, "VALM"),
ModelCategory::Multimodal => write!(f, "Multimodal"),
ModelCategory::Embodied => write!(f, "Embodied"),
}
}
}
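/// What a model can accept and emit, plus per-modality limits. The duration
/// fields are treated as seconds throughout this module (the presets below
/// use 3600 for one hour).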
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModalityCapabilities {
pub category: ModelCategory,
pub input_modalities: Vec<Modality>,
pub output_modalities: Vec<Modality>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_image_resolution: Option<(u32, u32)>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_video_duration: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_audio_duration: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_images_per_request: Option<u32>,
#[serde(default)]
pub supported_image_formats: Vec<ImageFormat>,
#[serde(default)]
pub supports_streaming: bool,
#[serde(default)]
pub supports_interleaved: bool,
}
impl Default for ModalityCapabilities {
fn default() -> Self {
Self {
category: ModelCategory::LLM,
input_modalities: vec![Modality::Text],
output_modalities: vec![Modality::Text],
max_image_resolution: None,
max_video_duration: None,
max_audio_duration: None,
max_images_per_request: None,
supported_image_formats: vec![],
supports_streaming: false,
supports_interleaved: false,
}
}
}
impl ModalityCapabilities {
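    /// Text-only preset: text in, text out, with streaming.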
pub fn llm() -> Self {
Self {
category: ModelCategory::LLM,
input_modalities: vec![Modality::Text],
output_modalities: vec![Modality::Text],
supports_streaming: true,
..Default::default()
}
}
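    /// Vision-language preset: interleaved text and image input, text output.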
pub fn vlm() -> Self {
Self {
category: ModelCategory::VLM,
input_modalities: vec![Modality::Text, Modality::Image],
output_modalities: vec![Modality::Text],
max_image_resolution: Some((4096, 4096)),
max_images_per_request: Some(20),
supported_image_formats: vec![
ImageFormat::Png,
ImageFormat::Jpeg,
ImageFormat::Webp,
ImageFormat::Gif,
],
supports_streaming: true,
supports_interleaved: true,
..Default::default()
}
}
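    /// Vision-language-action preset: adds sensor and depth inputs and emits
    /// actions and trajectories alongside text.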
pub fn vla() -> Self {
Self {
category: ModelCategory::VLA,
input_modalities: vec![
Modality::Text,
Modality::Image,
Modality::Sensor,
Modality::Depth,
],
output_modalities: vec![Modality::Text, Modality::Action, Modality::Trajectory],
max_image_resolution: Some((1024, 1024)),
max_images_per_request: Some(10),
supported_image_formats: vec![ImageFormat::Png, ImageFormat::Jpeg],
supports_streaming: true,
supports_interleaved: true,
..Default::default()
}
}
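    /// Any-to-any preset: text, image, audio, and video in; text, image, and
    /// audio out.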
pub fn multimodal() -> Self {
Self {
category: ModelCategory::Multimodal,
input_modalities: vec![
Modality::Text,
Modality::Image,
Modality::Audio,
Modality::Video,
],
output_modalities: vec![Modality::Text, Modality::Image, Modality::Audio],
max_image_resolution: Some((4096, 4096)),
max_video_duration: Some(3600),
max_audio_duration: Some(3600),
max_images_per_request: Some(50),
supported_image_formats: vec![
ImageFormat::Png,
ImageFormat::Jpeg,
ImageFormat::Webp,
ImageFormat::Gif,
],
supports_streaming: true,
supports_interleaved: true,
}
}
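    /// Embodied-agent preset: rich perception inputs (depth, point clouds,
    /// sensors, poses) and action-oriented outputs.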
pub fn embodied() -> Self {
Self {
category: ModelCategory::Embodied,
input_modalities: vec![
Modality::Text,
Modality::Image,
Modality::Depth,
Modality::PointCloud,
Modality::Sensor,
Modality::Pose,
],
output_modalities: vec![
Modality::Text,
Modality::Action,
Modality::Trajectory,
Modality::Pose,
],
max_image_resolution: Some((1280, 720)),
max_images_per_request: Some(8),
supported_image_formats: vec![ImageFormat::Png, ImageFormat::Jpeg],
supports_streaming: true,
supports_interleaved: true,
..Default::default()
}
}
pub fn supports_input(&self, modality: Modality) -> bool {
self.input_modalities.contains(&modality)
}
pub fn supports_output(&self, modality: Modality) -> bool {
self.output_modalities.contains(&modality)
}
pub fn supports_vision(&self) -> bool {
self.supports_input(Modality::Image) || self.supports_input(Modality::Video)
}
pub fn supports_actions(&self) -> bool {
self.supports_output(Modality::Action) || self.supports_output(Modality::Trajectory)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
Png,
Jpeg,
Webp,
Gif,
Bmp,
Tiff,
Heic,
}
impl std::fmt::Display for ImageFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ImageFormat::Png => write!(f, "png"),
ImageFormat::Jpeg => write!(f, "jpeg"),
ImageFormat::Webp => write!(f, "webp"),
ImageFormat::Gif => write!(f, "gif"),
ImageFormat::Bmp => write!(f, "bmp"),
ImageFormat::Tiff => write!(f, "tiff"),
ImageFormat::Heic => write!(f, "heic"),
}
}
}
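/// An image attachment, carried inline as base64 or by URL, optionally
/// annotated with labeled bounding-box regions.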
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageContent {
pub data: ImageData,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub format: Option<ImageFormat>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub alt_text: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub regions: Vec<BoundingBoxRegion>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
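// Untagged: a map carrying "base64" and "media_type" decodes as inline data;
// a map carrying "url" falls through to the Url variant.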
pub enum ImageData {
Base64 {
#[serde(rename = "base64")]
data: String,
media_type: String,
},
Url {
url: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
detail: Option<ImageDetail>,
},
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageDetail {
Low,
High,
Auto,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundingBoxRegion {
pub label: String,
pub x: f32,
pub y: f32,
pub width: f32,
pub height: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub confidence: Option<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoContent {
pub data: VideoData,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub duration: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub start_time: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub end_time: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub fps: Option<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum VideoData {
Url { url: String },
Frames { frames: Vec<ImageContent> },
Base64 { base64: String, media_type: String },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioContent {
pub data: AudioData,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub format: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub duration: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub sample_rate: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub transcription: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AudioData {
Url { url: String },
Base64 { base64: String, media_type: String },
}
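/// A structured action emitted by a VLA or embodied model, with optional
/// confidence, timestamp, and intended duration in milliseconds.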
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActionCommand {
pub action_type: ActionType,
pub parameters: ActionParameters,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub confidence: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub timestamp: Option<DateTime<Utc>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub duration_ms: Option<u64>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ActionType {
Move,
Rotate,
Stop,
Grasp,
Release,
Push,
Pull,
Place,
Pick,
Open,
Close,
MoveArm,
MoveJoint,
Look,
Focus,
Custom,
Wait,
Sequence,
}
// Untagged: serde tries variants in order, so the most constrained shapes
// come first and the permissive ones last. Movement requires explicit x/y/z
// (no serde defaults) so it cannot swallow payloads meant for later
// variants; Rotation (all fields optional) still absorbs any otherwise
// unrecognized map, and Custom catches everything else.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ActionParameters {
    TargetPose {
        position: [f64; 3],
        orientation: [f64; 4],
    },
    Trajectory {
        waypoints: Vec<Waypoint>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        interpolation: Option<String>,
    },
    JointPositions {
        positions: Vec<f64>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        joint_names: Vec<String>,
    },
    Gripper {
        aperture: f64,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        force: Option<f64>,
    },
    Movement {
        x: f64,
        y: f64,
        z: f64,
        #[serde(default)]
        is_velocity: bool,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        frame: Option<String>,
    },
    Rotation {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        roll: Option<f64>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        pitch: Option<f64>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        yaw: Option<f64>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        quaternion: Option<[f64; 4]>,
    },
    Custom(serde_json::Value),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Waypoint {
pub position: [f64; 3],
#[serde(default, skip_serializing_if = "Option::is_none")]
pub orientation: Option<[f64; 4]>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub time: Option<f64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub gripper: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensorData {
pub sensor_type: SensorType,
pub values: SensorValues,
pub timestamp: DateTime<Utc>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub frame: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SensorType {
JointState,
Imu,
ForceTorque,
Depth,
Lidar,
Localization,
Tactile,
Odometry,
Custom,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
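// Untagged: serde tries variants in order, so the permissive Numeric and
// Custom catch-alls must stay last or they would shadow the structured
// variants above them.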
pub enum SensorValues {
JointState {
positions: Vec<f64>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
velocities: Vec<f64>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
efforts: Vec<f64>,
},
Imu {
linear_acceleration: [f64; 3],
angular_velocity: [f64; 3],
#[serde(default, skip_serializing_if = "Option::is_none")]
orientation: Option<[f64; 4]>,
},
ForceTorque { force: [f64; 3], torque: [f64; 3] },
Depth {
data: String,
width: u32,
height: u32,
#[serde(default, skip_serializing_if = "Option::is_none")]
encoding: Option<String>,
},
PointCloud {
points: Vec<[f64; 3]>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
colors: Vec<[u8; 3]>,
},
Pose {
position: [f64; 3],
orientation: [f64; 4],
},
Numeric(Vec<f64>),
Custom(serde_json::Value),
}
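/// One element of a multimodal message body. Internally tagged, so every
/// part serializes with an explicit "type" discriminant.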
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
Text { text: String },
Image(ImageContent),
Video(VideoContent),
Audio(AudioContent),
Sensor(SensorData),
Action(ActionCommand),
File {
url: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
mime_type: Option<String>,
},
}
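/// A chat message whose body is an ordered list of content parts; the order
/// is preserved, which is what interleaved prompts rely on.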
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalMessage {
pub role: MessageRole,
pub content: Vec<ContentPart>,
#[serde(default = "Utc::now")]
pub timestamp: DateTime<Utc>,
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub metadata: HashMap<String, serde_json::Value>,
}
impl MultimodalMessage {
pub fn text(role: MessageRole, text: impl Into<String>) -> Self {
Self {
role,
content: vec![ContentPart::Text { text: text.into() }],
timestamp: Utc::now(),
metadata: HashMap::new(),
}
}
pub fn with_image(role: MessageRole, text: impl Into<String>, image: ImageContent) -> Self {
Self {
role,
content: vec![
ContentPart::Text { text: text.into() },
ContentPart::Image(image),
],
timestamp: Utc::now(),
metadata: HashMap::new(),
}
}
pub fn add_image(&mut self, image: ImageContent) {
self.content.push(ContentPart::Image(image));
}
pub fn add_sensor(&mut self, sensor: SensorData) {
self.content.push(ContentPart::Sensor(sensor));
}
pub fn add_action(&mut self, action: ActionCommand) {
self.content.push(ContentPart::Action(action));
}
pub fn text_content(&self) -> String {
self.content
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("\n")
}
pub fn images(&self) -> Vec<&ImageContent> {
self.content
.iter()
.filter_map(|part| match part {
ContentPart::Image(img) => Some(img),
_ => None,
})
.collect()
}
pub fn actions(&self) -> Vec<&ActionCommand> {
self.content
.iter()
.filter_map(|part| match part {
ContentPart::Action(action) => Some(action),
_ => None,
})
.collect()
}
}
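/// A registry entry describing a known model. Note that `category` also
/// appears inside `capabilities`; the registry entries below keep the two in
/// agreement (see the test module).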
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalModel {
pub id: String,
pub name: String,
pub provider: String,
pub category: ModelCategory,
pub capabilities: ModalityCapabilities,
pub max_context: u32,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub version: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub release_date: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub pricing: Option<ModelPricing>,
#[serde(default)]
pub available: bool,
#[serde(default)]
pub local: bool,
}
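/// Pricing metadata: token rates are per million tokens, media rates per
/// item or per minute, in `currency` (defaults to USD).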
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub input_per_million: Option<f64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub output_per_million: Option<f64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub per_image: Option<f64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub per_video_minute: Option<f64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub per_audio_minute: Option<f64>,
#[serde(default = "default_currency")]
pub currency: String,
}
fn default_currency() -> String {
"USD".to_string()
}
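/// Static registry of hosted and open vision-language models. Pricing,
/// context sizes, and availability are a point-in-time snapshot (late 2024)
/// and will drift; treat them as defaults, not ground truth.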
pub fn vlm_models() -> Vec<MultimodalModel> {
vec![
MultimodalModel {
id: "gpt-4o".to_string(),
name: "GPT-4o".to_string(),
provider: "OpenAI".to_string(),
category: ModelCategory::Multimodal,
capabilities: ModalityCapabilities::multimodal(),
max_context: 128000,
version: Some("2024-11-20".to_string()),
release_date: Some("2024-05-13".to_string()),
description: Some("Most capable GPT-4 with vision, audio, and text".to_string()),
pricing: Some(ModelPricing {
input_per_million: Some(2.50),
output_per_million: Some(10.00),
per_image: None,
per_video_minute: None,
per_audio_minute: None,
currency: "USD".to_string(),
}),
available: true,
local: false,
},
MultimodalModel {
id: "gpt-4o-mini".to_string(),
name: "GPT-4o Mini".to_string(),
provider: "OpenAI".to_string(),
category: ModelCategory::VLM,
capabilities: ModalityCapabilities::vlm(),
max_context: 128000,
version: Some("2024-07-18".to_string()),
release_date: Some("2024-07-18".to_string()),
description: Some("Affordable small model with vision capabilities".to_string()),
pricing: Some(ModelPricing {
input_per_million: Some(0.15),
output_per_million: Some(0.60),
per_image: None,
per_video_minute: None,
per_audio_minute: None,
currency: "USD".to_string(),
}),
available: true,
local: false,
},
MultimodalModel {
id: "gemini-2.0-flash".to_string(),
name: "Gemini 2.0 Flash".to_string(),
provider: "Google".to_string(),
category: ModelCategory::Multimodal,
capabilities: ModalityCapabilities::multimodal(),
max_context: 1000000,
version: Some("2.0".to_string()),
release_date: Some("2024-12-11".to_string()),
description: Some("Fastest Gemini with native multimodal generation".to_string()),
pricing: Some(ModelPricing {
input_per_million: Some(0.075),
output_per_million: Some(0.30),
per_image: None,
per_video_minute: None,
per_audio_minute: None,
currency: "USD".to_string(),
}),
available: true,
local: false,
},
MultimodalModel {
id: "gemini-1.5-pro".to_string(),
name: "Gemini 1.5 Pro".to_string(),
provider: "Google".to_string(),
category: ModelCategory::Multimodal,
capabilities: ModalityCapabilities::multimodal(),
max_context: 2000000,
version: Some("1.5".to_string()),
release_date: Some("2024-02-15".to_string()),
description: Some("2M context window with video understanding".to_string()),
pricing: Some(ModelPricing {
input_per_million: Some(1.25),
output_per_million: Some(5.00),
per_image: None,
per_video_minute: None,
per_audio_minute: None,
currency: "USD".to_string(),
}),
available: true,
local: false,
},
MultimodalModel {
id: "claude-3-5-sonnet".to_string(),
name: "Claude 3.5 Sonnet".to_string(),
provider: "Anthropic".to_string(),
category: ModelCategory::VLM,
capabilities: ModalityCapabilities::vlm(),
max_context: 200000,
version: Some("20241022".to_string()),
release_date: Some("2024-10-22".to_string()),
description: Some("Best overall Claude with strong vision".to_string()),
pricing: Some(ModelPricing {
input_per_million: Some(3.00),
output_per_million: Some(15.00),
per_image: None,
per_video_minute: None,
per_audio_minute: None,
currency: "USD".to_string(),
}),
available: true,
local: false,
},
MultimodalModel {
id: "llava-1.6".to_string(),
name: "LLaVA 1.6".to_string(),
provider: "Open Source".to_string(),
category: ModelCategory::VLM,
capabilities: ModalityCapabilities::vlm(),
max_context: 4096,
version: Some("1.6".to_string()),
release_date: Some("2024-01-30".to_string()),
description: Some("Open-source vision-language model".to_string()),
pricing: None,
available: true,
local: true,
},
MultimodalModel {
id: "qwen2-vl".to_string(),
name: "Qwen2-VL".to_string(),
provider: "Alibaba".to_string(),
category: ModelCategory::VLM,
capabilities: {
let mut caps = ModalityCapabilities::vlm();
caps.input_modalities.push(Modality::Video);
caps
},
max_context: 32768,
version: Some("2.0".to_string()),
release_date: Some("2024-08-29".to_string()),
description: Some("Strong open VLM with video understanding".to_string()),
pricing: None,
available: true,
local: true,
},
MultimodalModel {
id: "pixtral-12b".to_string(),
name: "Pixtral 12B".to_string(),
provider: "Mistral".to_string(),
category: ModelCategory::VLM,
capabilities: ModalityCapabilities::vlm(),
max_context: 128000,
version: Some("1.0".to_string()),
release_date: Some("2024-09-11".to_string()),
description: Some("Mistral's vision model, runs locally".to_string()),
pricing: None,
available: true,
local: true,
},
]
}
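/// Static registry of vision-language-action and embodied models. Several
/// are research systems without a public serving endpoint, hence
/// `available: false`.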
pub fn vla_models() -> Vec<MultimodalModel> {
vec![
MultimodalModel {
id: "rt-2".to_string(),
name: "RT-2".to_string(),
provider: "Google DeepMind".to_string(),
category: ModelCategory::VLA,
capabilities: ModalityCapabilities::vla(),
max_context: 4096,
version: Some("2.0".to_string()),
release_date: Some("2023-07-28".to_string()),
description: Some("Robotics Transformer 2 - vision-language-action model".to_string()),
pricing: None,
available: false,
local: false,
},
MultimodalModel {
id: "rt-x".to_string(),
name: "RT-X".to_string(),
provider: "Open X-Embodiment".to_string(),
category: ModelCategory::VLA,
capabilities: ModalityCapabilities::vla(),
max_context: 4096,
version: Some("1.0".to_string()),
release_date: Some("2023-10-05".to_string()),
description: Some("Cross-embodiment robotics foundation model".to_string()),
pricing: None,
available: true,
local: true,
},
MultimodalModel {
id: "octo".to_string(),
name: "Octo".to_string(),
provider: "Berkeley AI Research".to_string(),
category: ModelCategory::VLA,
capabilities: ModalityCapabilities::vla(),
max_context: 2048,
version: Some("1.0".to_string()),
release_date: Some("2024-05-10".to_string()),
description: Some("Generalist robot policy from Open X-Embodiment".to_string()),
pricing: None,
available: true,
local: true,
},
MultimodalModel {
id: "openvla".to_string(),
name: "OpenVLA".to_string(),
provider: "Stanford/Berkeley".to_string(),
category: ModelCategory::VLA,
capabilities: ModalityCapabilities::vla(),
max_context: 4096,
version: Some("7B".to_string()),
release_date: Some("2024-06-13".to_string()),
description: Some("Open-source 7B parameter VLA model".to_string()),
pricing: None,
available: true,
local: true,
},
MultimodalModel {
id: "palm-e".to_string(),
name: "PaLM-E".to_string(),
provider: "Google".to_string(),
category: ModelCategory::Embodied,
capabilities: ModalityCapabilities::embodied(),
max_context: 8192,
version: Some("562B".to_string()),
release_date: Some("2023-03-06".to_string()),
description: Some("Embodied multimodal language model".to_string()),
pricing: None,
available: false,
local: false,
},
MultimodalModel {
id: "gr-1".to_string(),
name: "GR-1".to_string(),
provider: "Fourier Intelligence".to_string(),
category: ModelCategory::VLA,
capabilities: ModalityCapabilities::vla(),
max_context: 2048,
version: Some("1.0".to_string()),
release_date: Some("2024-03-18".to_string()),
description: Some("VLA for humanoid robot manipulation".to_string()),
pricing: None,
available: false,
local: false,
},
MultimodalModel {
id: "pi0".to_string(),
name: "Pi-Zero".to_string(),
provider: "Physical Intelligence".to_string(),
category: ModelCategory::VLA,
capabilities: ModalityCapabilities::vla(),
max_context: 4096,
version: Some("1.0".to_string()),
release_date: Some("2024-10-31".to_string()),
description: Some("General-purpose robot foundation model".to_string()),
pricing: None,
available: false,
local: false,
},
]
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_modality_display() {
assert_eq!(format!("{}", Modality::Text), "text");
assert_eq!(format!("{}", Modality::Image), "image");
assert_eq!(format!("{}", Modality::Action), "action");
}
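    #[test]
    fn test_modality_serde_matches_display() {
        // The serde snake_case names and the Display impl are maintained in
        // parallel; this guards against the two drifting apart.
        for m in [Modality::Text, Modality::PointCloud, Modality::BoundingBox].iter() {
            let json = serde_json::to_string(m).unwrap();
            assert_eq!(json, format!("\"{}\"", m));
        }
    }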
#[test]
fn test_model_category_display() {
assert_eq!(format!("{}", ModelCategory::LLM), "LLM");
assert_eq!(format!("{}", ModelCategory::VLM), "VLM");
assert_eq!(format!("{}", ModelCategory::VLA), "VLA");
}
#[test]
fn test_vlm_capabilities() {
let caps = ModalityCapabilities::vlm();
assert!(caps.supports_input(Modality::Text));
assert!(caps.supports_input(Modality::Image));
assert!(!caps.supports_input(Modality::Action));
assert!(caps.supports_vision());
assert!(!caps.supports_actions());
}
#[test]
fn test_vla_capabilities() {
let caps = ModalityCapabilities::vla();
assert!(caps.supports_input(Modality::Text));
assert!(caps.supports_input(Modality::Image));
assert!(caps.supports_input(Modality::Sensor));
assert!(caps.supports_output(Modality::Action));
assert!(caps.supports_output(Modality::Trajectory));
assert!(caps.supports_vision());
assert!(caps.supports_actions());
}
#[test]
fn test_multimodal_message() {
let mut msg = MultimodalMessage::text(MessageRole::User, "What's in this image?");
msg.add_image(ImageContent {
data: ImageData::Url {
url: "https://example.com/image.jpg".to_string(),
detail: Some(ImageDetail::Auto),
},
format: Some(ImageFormat::Jpeg),
alt_text: Some("Test image".to_string()),
regions: vec![],
});
assert_eq!(msg.images().len(), 1);
assert_eq!(msg.text_content(), "What's in this image?");
}
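    #[test]
    fn test_content_part_tagging() {
        // ContentPart is internally tagged on "type", so every part should
        // serialize with an explicit discriminant.
        let part = ContentPart::Text {
            text: "hello".to_string(),
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "text");
        assert_eq!(json["text"], "hello");
    }
    #[test]
    fn test_interleaved_content_order() {
        // Content parts keep insertion order, which interleaved
        // vision-language prompts depend on.
        let mut msg = MultimodalMessage::text(MessageRole::User, "before");
        msg.add_image(ImageContent {
            data: ImageData::Url {
                url: "https://example.com/a.png".to_string(),
                detail: None,
            },
            format: Some(ImageFormat::Png),
            alt_text: None,
            regions: vec![],
        });
        msg.content.push(ContentPart::Text {
            text: "after".to_string(),
        });
        assert_eq!(msg.text_content(), "before\nafter");
        assert_eq!(msg.images().len(), 1);
    }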
#[test]
fn test_action_command() {
let action = ActionCommand {
action_type: ActionType::Grasp,
parameters: ActionParameters::Gripper {
aperture: 0.5,
force: Some(10.0),
},
confidence: Some(0.95),
timestamp: None,
duration_ms: Some(500),
};
assert_eq!(action.action_type, ActionType::Grasp);
}
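    #[test]
    fn test_action_parameters_roundtrip() {
        // ActionParameters is untagged, so variant order decides how JSON is
        // decoded; this checks that a gripper command survives a round trip
        // instead of being captured by a more permissive variant.
        let params = ActionParameters::Gripper {
            aperture: 0.5,
            force: None,
        };
        let json = serde_json::to_string(&params).unwrap();
        let back: ActionParameters = serde_json::from_str(&json).unwrap();
        match back {
            ActionParameters::Gripper { aperture, .. } => {
                assert!((aperture - 0.5).abs() < f64::EPSILON)
            }
            other => panic!("expected Gripper, got {:?}", other),
        }
    }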
#[test]
fn test_vlm_models_registry() {
let models = vlm_models();
assert!(!models.is_empty());
let gpt4o = models.iter().find(|m| m.id == "gpt-4o").unwrap();
assert_eq!(gpt4o.category, ModelCategory::Multimodal);
assert!(gpt4o.available);
}
#[test]
fn test_vla_models_registry() {
let models = vla_models();
assert!(!models.is_empty());
let openvla = models.iter().find(|m| m.id == "openvla").unwrap();
assert_eq!(openvla.category, ModelCategory::VLA);
assert!(openvla.local);
}
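    #[test]
    fn test_registry_categories_match_capabilities() {
        // Sanity check: each registry entry's top-level category must agree
        // with the category recorded inside its capabilities.
        for model in vlm_models().into_iter().chain(vla_models()) {
            assert_eq!(
                model.category, model.capabilities.category,
                "category mismatch for {}",
                model.id
            );
        }
    }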
}