from sparkless.testing import get_imports
import random
from datetime import datetime, timedelta
import time
def test_quickstart(spark):
    """Walk through the quickstart tutorial steps against ``spark``.

    Steps, in order:
      1. Report the session's application name.
      2. Build a five-row orders DataFrame and check its row count.
      3. Add a computed ``total`` column and filter on it.
      4. Aggregate revenue and order count per customer.
      5. Run an equivalent aggregation through ``spark.sql`` on a temp view.
      6. Build a 100-row pseudo-random dataset and time a
         withColumn/filter/groupBy/agg/orderBy/collect pipeline.

    Args:
        spark: Session object; this test uses its ``createDataFrame`` and
            ``sql`` methods and, when present, ``appName`` / ``app_name``.
    """
    imports = get_imports()
    F = imports.F
    print("Testing Quickstart Tutorial...")

    # The attribute holding the application name differs between session
    # implementations; fall back to "test" when neither spelling exists.
    app_name = (
        spark.appName
        if hasattr(spark, "appName")
        else getattr(spark, "app_name", "test")
    )
    print(f"✅ Session created: {app_name}")

    data = [
        {
            "order_id": 1,
            "customer": "Alice",
            "product": "Laptop",
            "quantity": 1,
            "price": 1200,
        },
        {
            "order_id": 2,
            "customer": "Bob",
            "product": "Mouse",
            "quantity": 2,
            "price": 25,
        },
        {
            "order_id": 3,
            "customer": "Alice",
            "product": "Keyboard",
            "quantity": 1,
            "price": 100,
        },
        {
            "order_id": 4,
            "customer": "Charlie",
            "product": "Monitor",
            "quantity": 2,
            "price": 300,
        },
        {
            "order_id": 5,
            "customer": "Bob",
            "product": "Laptop",
            "quantity": 1,
            "price": 1200,
        },
    ]
    df = spark.createDataFrame(data)
    # Each count is computed once and reused for both the assert and the print.
    order_total = df.count()
    assert order_total == 5
    print(f"✅ Created DataFrame with {order_total} orders")

    df_with_total = df.withColumn("total", F.col("quantity") * F.col("price"))
    assert df_with_total.count() == 5

    # Totals are 1200, 50, 100, 600, 1200 -> three of them exceed 500.
    high_value_orders = df_with_total.filter(F.col("total") > 500)
    high_value_total = high_value_orders.count()
    assert high_value_total == 3
    print(f"✅ Filter works: {high_value_total} high-value orders")

    customer_revenue = (
        df_with_total.groupBy("customer")
        .agg(
            F.sum("total").alias("total_revenue"),
            F.count("order_id").alias("order_count"),
        )
        .orderBy(F.desc("total_revenue"))
    )
    customer_total = customer_revenue.count()
    assert customer_total == 3
    print(f"✅ Aggregations work: {customer_total} customers")

    df.createOrReplaceTempView("orders")
    result = spark.sql("""
SELECT
customer,
COUNT(*) as order_count,
SUM(quantity * price) as total_spent,
AVG(quantity * price) as avg_order_value
FROM orders
GROUP BY customer
HAVING SUM(quantity * price) > 100
ORDER BY total_spent DESC
""")
    sql_total = result.count()
    assert sql_total >= 1
    print(f"✅ SQL works: {sql_total} results")

    customers = ["Alice", "Bob", "Charlie", "Diana", "Eve"]
    products = ["Laptop", "Mouse", "Keyboard", "Monitor", "Headphones"]
    start_date = datetime(2024, 1, 1)
    # A locally seeded generator keeps the generated rows identical on every
    # run (so failures are reproducible) without touching the global
    # ``random`` state used by other tests.
    rng = random.Random(42)
    large_data = [
        {
            "order_id": i,
            "customer": rng.choice(customers),
            "product": rng.choice(products),
            "quantity": rng.randint(1, 5),
            "price": rng.randint(10, 1500),
            "order_date": (start_date + timedelta(days=rng.randint(0, 90))).strftime(
                "%Y-%m-%d"
            ),
        }
        for i in range(100)
    ]
    large_df = spark.createDataFrame(large_data)
    large_total = large_df.count()
    assert large_total == 100
    print(f"✅ Large dataset created: {large_total} orders")

    # perf_counter is monotonic, so the elapsed value cannot go negative or
    # jump if the system clock is adjusted mid-run.
    start_time = time.perf_counter()
    result = (
        large_df.withColumn("total", F.col("quantity") * F.col("price"))
        .filter(F.col("total") > 100)
        .groupBy("customer", "product")
        .agg(F.sum("total").alias("revenue"))
        .orderBy(F.desc("revenue"))
        .collect()
    )
    elapsed = time.perf_counter() - start_time
    print(f"✅ Performance test: {elapsed:.4f} seconds for {len(result)} results")
    print("✅ Quickstart tutorial: ALL TESTS PASSED\n")
def test_dataframe_operations(spark):
    """Check the basic DataFrame operations shown in the tutorial.

    Builds a five-row employees DataFrame and a three-row departments
    DataFrame, then verifies select, filter, inner join on ``dept_id``,
    withColumn, descending orderBy, and distinct.

    Args:
        spark: Session object; this test uses its ``createDataFrame`` method.
    """
    imports = get_imports()
    F = imports.F
    print("Testing DataFrame Operations Tutorial...")

    employees = [
        {"emp_id": 1, "name": "Alice", "dept_id": 10, "salary": 80000, "city": "NYC"},
        {"emp_id": 2, "name": "Bob", "dept_id": 20, "salary": 75000, "city": "LA"},
        {"emp_id": 3, "name": "Charlie", "dept_id": 10, "salary": 90000, "city": "NYC"},
        {
            "emp_id": 4,
            "name": "Diana",
            "dept_id": 30,
            "salary": 85000,
            "city": "Chicago",
        },
        {"emp_id": 5, "name": "Eve", "dept_id": 20, "salary": 95000, "city": "LA"},
    ]
    departments = [
        {"dept_id": 10, "dept_name": "Engineering", "budget": 500000},
        {"dept_id": 20, "dept_name": "Sales", "budget": 300000},
        {"dept_id": 30, "dept_name": "Marketing", "budget": 200000},
    ]
    emp_df = spark.createDataFrame(employees)
    dept_df = spark.createDataFrame(departments)

    result = emp_df.select("name", "salary")
    assert result.count() == 5
    print("✅ Select works")

    # Salaries strictly above 80000: Charlie (90000), Diana (85000), Eve (95000).
    high_earners = emp_df.filter(F.col("salary") > 80000)
    assert high_earners.count() == 3
    print("✅ Filter works")

    # Every employee's dept_id exists in departments, so no rows are dropped.
    joined = emp_df.join(dept_df, "dept_id", "inner")
    assert joined.count() == 5
    print("✅ Joins work")

    enriched = emp_df.withColumn("salary_k", F.col("salary") / 1000)
    assert "salary_k" in enriched.columns
    print("✅ WithColumn works")

    sorted_df = emp_df.orderBy(F.desc("salary"))
    first_row = sorted_df.limit(1).collect()[0]
    assert first_row["name"] == "Eve"
    print("✅ OrderBy works")

    cities = emp_df.select("city").distinct()
    city_count = cities.count()
    print(f" Distinct cities: {city_count}")
    # The assert and the print were fused onto one line, which is a syntax
    # error; they are two separate statements. The bound is left at >= 1 as
    # originally written.
    assert city_count >= 1
    print("✅ Distinct works")
    print("✅ DataFrame Operations tutorial: ALL TESTS PASSED\n")